diff --git a/.asf.yaml b/.asf.yaml
index 22042b355b2fa..3935a525ff3c4 100644
--- a/.asf.yaml
+++ b/.asf.yaml
@@ -31,6 +31,8 @@ github:
merge: false
squash: true
rebase: true
+ ghp_branch: master
+ ghp_path: /docs
notifications:
pullrequests: reviews@spark.apache.org
diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml
index 0bf7e57c364e4..2b459e4c73bbb 100644
--- a/.github/workflows/build_and_test.yml
+++ b/.github/workflows/build_and_test.yml
@@ -304,7 +304,7 @@ jobs:
uses: actions/upload-artifact@v4
with:
name: unit-tests-log-${{ matrix.modules }}-${{ matrix.comment }}-${{ matrix.java }}-${{ matrix.hadoop }}-${{ matrix.hive }}
- path: "**/target/unit-tests.log"
+ path: "**/target/*.log"
infra-image:
name: "Base image build"
@@ -723,7 +723,7 @@ jobs:
# See 'ipython_genutils' in SPARK-38517
# See 'docutils<0.18.0' in SPARK-39421
python3.9 -m pip install 'sphinx==4.5.0' mkdocs 'pydata_sphinx_theme>=0.13' sphinx-copybutton nbsphinx numpydoc jinja2 markupsafe 'pyzmq<24.0.0' \
- ipython ipython_genutils sphinx_plotly_directive 'numpy>=1.20.0' pyarrow pandas 'plotly>=4.8' 'docutils<0.18.0' \
+ ipython ipython_genutils sphinx_plotly_directive 'numpy==1.26.4' pyarrow pandas 'plotly>=4.8' 'docutils<0.18.0' \
'flake8==3.9.0' 'mypy==1.8.0' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' 'black==23.9.1' \
'pandas-stubs==1.2.0.53' 'grpcio==1.62.0' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \
'sphinxcontrib-applehelp==1.0.4' 'sphinxcontrib-devhelp==1.0.2' 'sphinxcontrib-htmlhelp==2.0.1' 'sphinxcontrib-qthelp==1.0.3' 'sphinxcontrib-serializinghtml==1.1.5'
diff --git a/.github/workflows/build_python_connect.yml b/.github/workflows/build_python_connect.yml
index 3ac1a0117e41b..f668d813ef26e 100644
--- a/.github/workflows/build_python_connect.yml
+++ b/.github/workflows/build_python_connect.yml
@@ -71,7 +71,7 @@ jobs:
python packaging/connect/setup.py sdist
cd dist
pip install pyspark*connect-*.tar.gz
- pip install 'six==1.16.0' 'pandas<=2.2.2' scipy 'plotly>=4.8' 'mlflow>=2.8.1' coverage matplotlib openpyxl 'memory-profiler>=0.61.0' 'scikit-learn>=1.3.2' 'graphviz==0.20.3' torch torchvision torcheval deepspeed unittest-xml-reporting
+ pip install 'six==1.16.0' 'pandas<=2.2.2' scipy 'plotly>=4.8' 'mlflow>=2.8.1' coverage matplotlib openpyxl 'memory-profiler>=0.61.0' 'scikit-learn>=1.3.2' 'graphviz==0.20.3' torch torchvision torcheval deepspeed unittest-xml-reporting 'plotly>=4.8'
- name: Run tests
env:
SPARK_TESTING: 1
diff --git a/.github/workflows/pages.yml b/.github/workflows/pages.yml
new file mode 100644
index 0000000000000..8faeb0557fbfb
--- /dev/null
+++ b/.github/workflows/pages.yml
@@ -0,0 +1,97 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+name: GitHub Pages deployment
+
+on:
+ push:
+ branches:
+ - master
+
+concurrency:
+ group: 'docs preview'
+ cancel-in-progress: false
+
+jobs:
+ docs:
+ name: Build and deploy documentation
+ runs-on: ubuntu-latest
+ permissions:
+ id-token: write
+ pages: write
+ environment:
+ name: github-pages # https://github.com/actions/deploy-pages/issues/271
+ env:
+ SPARK_TESTING: 1 # Reduce some noise in the logs
+ RELEASE_VERSION: 'In-Progress'
+ steps:
+ - name: Checkout Spark repository
+ uses: actions/checkout@v4
+ with:
+ repository: apache/spark
+ ref: 'master'
+ - name: Install Java 17
+ uses: actions/setup-java@v4
+ with:
+ distribution: zulu
+ java-version: 17
+ - name: Install Python 3.9
+ uses: actions/setup-python@v5
+ with:
+ python-version: '3.9'
+ architecture: x64
+ cache: 'pip'
+ - name: Install Python dependencies
+ run: |
+ pip install 'sphinx==4.5.0' mkdocs 'pydata_sphinx_theme>=0.13' sphinx-copybutton nbsphinx numpydoc jinja2 markupsafe 'pyzmq<24.0.0' \
+ ipython ipython_genutils sphinx_plotly_directive 'numpy>=1.20.0' pyarrow 'pandas==2.2.2' 'plotly>=4.8' 'docutils<0.18.0' \
+ 'flake8==3.9.0' 'mypy==1.8.0' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' 'black==23.9.1' \
+ 'pandas-stubs==1.2.0.53' 'grpcio==1.62.0' 'grpcio-status==1.62.0' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \
+ 'sphinxcontrib-applehelp==1.0.4' 'sphinxcontrib-devhelp==1.0.2' 'sphinxcontrib-htmlhelp==2.0.1' 'sphinxcontrib-qthelp==1.0.3' 'sphinxcontrib-serializinghtml==1.1.5'
+ - name: Install Ruby for documentation generation
+ uses: ruby/setup-ruby@v1
+ with:
+ ruby-version: '3.3'
+ bundler-cache: true
+ - name: Install Pandoc
+ run: |
+ sudo apt-get update -y
+ sudo apt-get install pandoc
+ - name: Install dependencies for documentation generation
+ run: |
+ cd docs
+ gem install bundler -v 2.4.22 -n /usr/local/bin
+ bundle install --retry=100
+ - name: Run documentation build
+ run: |
+ sed -i".tmp1" 's/SPARK_VERSION:.*$/SPARK_VERSION: '"$RELEASE_VERSION"'/g' docs/_config.yml
+ sed -i".tmp2" 's/SPARK_VERSION_SHORT:.*$/SPARK_VERSION_SHORT: '"$RELEASE_VERSION"'/g' docs/_config.yml
+ sed -i".tmp3" "s/'facetFilters':.*$/'facetFilters': [\"version:$RELEASE_VERSION\"]/g" docs/_config.yml
+ sed -i".tmp4" 's/__version__: str = .*$/__version__: str = "'"$RELEASE_VERSION"'"/' python/pyspark/version.py
+ cd docs
+ SKIP_RDOC=1 bundle exec jekyll build
+ - name: Setup Pages
+ uses: actions/configure-pages@v5
+ - name: Upload artifact
+ uses: actions/upload-pages-artifact@v3
+ with:
+ path: 'docs/_site'
+ - name: Deploy to GitHub Pages
+ id: deployment
+ uses: actions/deploy-pages@v4
diff --git a/.github/workflows/test_report.yml b/.github/workflows/test_report.yml
index c6225e6a1abe5..9ab69af42c818 100644
--- a/.github/workflows/test_report.yml
+++ b/.github/workflows/test_report.yml
@@ -30,14 +30,14 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Download test results to report
- uses: dawidd6/action-download-artifact@09385b76de790122f4da9c82b17bccf858b9557c # pin@v2
+ uses: dawidd6/action-download-artifact@bf251b5aa9c2f7eeb574a96ee720e24f801b7c11 # pin @v6
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
workflow: ${{ github.event.workflow_run.workflow_id }}
commit: ${{ github.event.workflow_run.head_commit.id }}
workflow_conclusion: completed
- name: Publish test report
- uses: scacap/action-surefire-report@482f012643ed0560e23ef605a79e8e87ca081648 # pin@v1
+ uses: scacap/action-surefire-report@a2911bd1a4412ec18dde2d93b1758b3e56d2a880 # pin @v1.8.0
with:
check_name: Report test results
github_token: ${{ secrets.GITHUB_TOKEN }}
diff --git a/.nojekyll b/.nojekyll
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/assembly/pom.xml b/assembly/pom.xml
index e5628ce90fa90..01bd324efc118 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -123,7 +123,7 @@
com.google.guava
@@ -200,7 +200,7 @@
cp
- ${basedir}/../connector/connect/client/jvm/target/spark-connect-client-jvm_${scala.binary.version}-${version}.jar
+ ${basedir}/../connector/connect/client/jvm/target/spark-connect-client-jvm_${scala.binary.version}-${project.version}.jar${basedir}/target/scala-${scala.binary.version}/jars/connect-repl
@@ -339,6 +339,14 @@
+
+
+ jjwt
+
+ compile
+
+
+
diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java
index 5ed3048fb72b3..fb610a5d96f17 100644
--- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java
+++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java
@@ -109,7 +109,7 @@ private static int lowercaseMatchLengthFrom(
}
// Compare the characters in the target and pattern strings.
int matchLength = 0, codePointBuffer = -1, targetCodePoint, patternCodePoint;
- while (targetIterator.hasNext() && patternIterator.hasNext()) {
+ while ((targetIterator.hasNext() || codePointBuffer != -1) && patternIterator.hasNext()) {
if (codePointBuffer != -1) {
targetCodePoint = codePointBuffer;
codePointBuffer = -1;
@@ -211,7 +211,7 @@ private static int lowercaseMatchLengthUntil(
}
// Compare the characters in the target and pattern strings.
int matchLength = 0, codePointBuffer = -1, targetCodePoint, patternCodePoint;
- while (targetIterator.hasNext() && patternIterator.hasNext()) {
+ while ((targetIterator.hasNext() || codePointBuffer != -1) && patternIterator.hasNext()) {
if (codePointBuffer != -1) {
targetCodePoint = codePointBuffer;
codePointBuffer = -1;
diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
index 5640a2468d02e..d5dbca7eb89bc 100644
--- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
+++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
@@ -23,12 +23,14 @@
import java.util.function.Function;
import java.util.function.BiFunction;
import java.util.function.ToLongFunction;
+import java.util.stream.Stream;
+import com.ibm.icu.text.CollationKey;
+import com.ibm.icu.text.Collator;
import com.ibm.icu.text.RuleBasedCollator;
import com.ibm.icu.text.StringSearch;
import com.ibm.icu.util.ULocale;
-import com.ibm.icu.text.CollationKey;
-import com.ibm.icu.text.Collator;
+import com.ibm.icu.util.VersionInfo;
import org.apache.spark.SparkException;
import org.apache.spark.unsafe.types.UTF8String;
@@ -88,6 +90,17 @@ public Optional<String> getVersion() {
}
}
+ public record CollationMeta(
+ String catalog,
+ String schema,
+ String collationName,
+ String language,
+ String country,
+ String icuVersion,
+ String padAttribute,
+ boolean accentSensitivity,
+ boolean caseSensitivity) { }
+
/**
* Entry encapsulating all information about a collation.
*/
@@ -342,6 +355,23 @@ private static int collationNameToId(String collationName) throws SparkException
}
protected abstract Collation buildCollation();
+
+ protected abstract CollationMeta buildCollationMeta();
+
+ static List<CollationIdentifier> listCollations() {
+ return Stream.concat(
+ CollationSpecUTF8.listCollations().stream(),
+ CollationSpecICU.listCollations().stream()).toList();
+ }
+
+ static CollationMeta loadCollationMeta(CollationIdentifier collationIdentifier) {
+ CollationMeta collationSpecUTF8 =
+ CollationSpecUTF8.loadCollationMeta(collationIdentifier);
+ if (collationSpecUTF8 == null) {
+ return CollationSpecICU.loadCollationMeta(collationIdentifier);
+ }
+ return collationSpecUTF8;
+ }
}
private static class CollationSpecUTF8 extends CollationSpec {
@@ -364,6 +394,9 @@ private enum CaseSensitivity {
*/
private static final int CASE_SENSITIVITY_MASK = 0b1;
+ private static final String UTF8_BINARY_COLLATION_NAME = "UTF8_BINARY";
+ private static final String UTF8_LCASE_COLLATION_NAME = "UTF8_LCASE";
+
private static final int UTF8_BINARY_COLLATION_ID =
new CollationSpecUTF8(CaseSensitivity.UNSPECIFIED).collationId;
private static final int UTF8_LCASE_COLLATION_ID =
@@ -406,7 +439,7 @@ private static CollationSpecUTF8 fromCollationId(int collationId) {
protected Collation buildCollation() {
if (collationId == UTF8_BINARY_COLLATION_ID) {
return new Collation(
- "UTF8_BINARY",
+ UTF8_BINARY_COLLATION_NAME,
PROVIDER_SPARK,
null,
UTF8String::binaryCompare,
@@ -417,7 +450,7 @@ protected Collation buildCollation() {
/* supportsLowercaseEquality = */ false);
} else {
return new Collation(
- "UTF8_LCASE",
+ UTF8_LCASE_COLLATION_NAME,
PROVIDER_SPARK,
null,
CollationAwareUTF8String::compareLowerCase,
@@ -428,6 +461,52 @@ protected Collation buildCollation() {
/* supportsLowercaseEquality = */ true);
}
}
+
+ @Override
+ protected CollationMeta buildCollationMeta() {
+ if (collationId == UTF8_BINARY_COLLATION_ID) {
+ return new CollationMeta(
+ CATALOG,
+ SCHEMA,
+ UTF8_BINARY_COLLATION_NAME,
+ /* language = */ null,
+ /* country = */ null,
+ /* icuVersion = */ null,
+ COLLATION_PAD_ATTRIBUTE,
+ /* accentSensitivity = */ true,
+ /* caseSensitivity = */ true);
+ } else {
+ return new CollationMeta(
+ CATALOG,
+ SCHEMA,
+ UTF8_LCASE_COLLATION_NAME,
+ /* language = */ null,
+ /* country = */ null,
+ /* icuVersion = */ null,
+ COLLATION_PAD_ATTRIBUTE,
+ /* accentSensitivity = */ true,
+ /* caseSensitivity = */ false);
+ }
+ }
+
+ static List<CollationIdentifier> listCollations() {
+ CollationIdentifier UTF8_BINARY_COLLATION_IDENT =
+ new CollationIdentifier(PROVIDER_SPARK, UTF8_BINARY_COLLATION_NAME, "1.0");
+ CollationIdentifier UTF8_LCASE_COLLATION_IDENT =
+ new CollationIdentifier(PROVIDER_SPARK, UTF8_LCASE_COLLATION_NAME, "1.0");
+ return Arrays.asList(UTF8_BINARY_COLLATION_IDENT, UTF8_LCASE_COLLATION_IDENT);
+ }
+
+ static CollationMeta loadCollationMeta(CollationIdentifier collationIdentifier) {
+ try {
+ int collationId = CollationSpecUTF8.collationNameToId(
+ collationIdentifier.name, collationIdentifier.name.toUpperCase());
+ return CollationSpecUTF8.fromCollationId(collationId).buildCollationMeta();
+ } catch (SparkException ignored) {
+ // ignore
+ return null;
+ }
+ }
}
private static class CollationSpecICU extends CollationSpec {
@@ -684,6 +763,20 @@ protected Collation buildCollation() {
/* supportsLowercaseEquality = */ false);
}
+ @Override
+ protected CollationMeta buildCollationMeta() {
+ return new CollationMeta(
+ CATALOG,
+ SCHEMA,
+ collationName(),
+ ICULocaleMap.get(locale).getDisplayLanguage(),
+ ICULocaleMap.get(locale).getDisplayCountry(),
+ VersionInfo.ICU_VERSION.toString(),
+ COLLATION_PAD_ATTRIBUTE,
+ accentSensitivity == AccentSensitivity.AS,
+ caseSensitivity == CaseSensitivity.CS);
+ }
+
/**
* Compute normalized collation name. Components of collation name are given in order:
* - Locale name
@@ -704,6 +797,37 @@ private String collationName() {
}
return builder.toString();
}
+
+ private static List<String> allCollationNames() {
+ List<String> collationNames = new ArrayList<>();
+ for (String locale: ICULocaleToId.keySet()) {
+ // CaseSensitivity.CS + AccentSensitivity.AS
+ collationNames.add(locale);
+ // CaseSensitivity.CS + AccentSensitivity.AI
+ collationNames.add(locale + "_AI");
+ // CaseSensitivity.CI + AccentSensitivity.AS
+ collationNames.add(locale + "_CI");
+ // CaseSensitivity.CI + AccentSensitivity.AI
+ collationNames.add(locale + "_CI_AI");
+ }
+ return collationNames.stream().sorted().toList();
+ }
+
+ static List<CollationIdentifier> listCollations() {
+ return allCollationNames().stream().map(name ->
+ new CollationIdentifier(PROVIDER_ICU, name, VersionInfo.ICU_VERSION.toString())).toList();
+ }
+
+ static CollationMeta loadCollationMeta(CollationIdentifier collationIdentifier) {
+ try {
+ int collationId = CollationSpecICU.collationNameToId(
+ collationIdentifier.name, collationIdentifier.name.toUpperCase());
+ return CollationSpecICU.fromCollationId(collationId).buildCollationMeta();
+ } catch (SparkException ignored) {
+ // ignore
+ return null;
+ }
+ }
}
/**
@@ -730,9 +854,12 @@ public CollationIdentifier identifier() {
}
}
+ public static final String CATALOG = "SYSTEM";
+ public static final String SCHEMA = "BUILTIN";
public static final String PROVIDER_SPARK = "spark";
public static final String PROVIDER_ICU = "icu";
public static final List<String> SUPPORTED_PROVIDERS = List.of(PROVIDER_SPARK, PROVIDER_ICU);
+ public static final String COLLATION_PAD_ATTRIBUTE = "NO_PAD";
public static final int UTF8_BINARY_COLLATION_ID =
Collation.CollationSpecUTF8.UTF8_BINARY_COLLATION_ID;
@@ -794,6 +921,18 @@ public static int collationNameToId(String collationName) throws SparkException
return Collation.CollationSpec.collationNameToId(collationName);
}
+ /**
+ * Returns whether the ICU collation is Case Sensitive and Accent Insensitive
+ * for the given collation id.
+ * This method is used in expressions which do not support CS_AI collations.
+ */
+ public static boolean isCaseSensitiveAndAccentInsensitive(int collationId) {
+ return Collation.CollationSpecICU.fromCollationId(collationId).caseSensitivity ==
+ Collation.CollationSpecICU.CaseSensitivity.CS &&
+ Collation.CollationSpecICU.fromCollationId(collationId).accentSensitivity ==
+ Collation.CollationSpecICU.AccentSensitivity.AI;
+ }
+
public static void assertValidProvider(String provider) throws SparkException {
if (!SUPPORTED_PROVIDERS.contains(provider.toLowerCase())) {
Map<String, String> params = Map.of(
@@ -923,4 +1062,12 @@ public static String getClosestSuggestionsOnInvalidName(
return String.join(", ", suggestions);
}
+
+ public static List<CollationIdentifier> listCollations() {
+ return Collation.CollationSpec.listCollations();
+ }
+
+ public static CollationMeta loadCollationMeta(CollationIdentifier collationIdentifier) {
+ return Collation.CollationSpec.loadCollationMeta(collationIdentifier);
+ }
}
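
Aside from the patch itself, here is a minimal sketch (not part of this diff) of how the new public static entry points `CollationFactory.listCollations()` and `CollationFactory.loadCollationMeta(...)` introduced above could be exercised from Scala. It assumes the `CollationMeta` record accessors (`catalog()`, `collationName()`, `padAttribute()`, ...) declared in this diff, and uses the returned `java.util.List` directly.

```scala
import org.apache.spark.sql.catalyst.util.CollationFactory

object CollationCatalogSketch {
  def main(args: Array[String]): Unit = {
    // Every built-in collation: UTF8_BINARY, UTF8_LCASE, plus the ICU locale variants.
    val ids = CollationFactory.listCollations()
    println(s"Number of collations: ${ids.size()}")

    // Load the metadata record for the first identifier and print a few fields.
    val meta = CollationFactory.loadCollationMeta(ids.get(0))
    println(s"${meta.catalog()}.${meta.schema()}.${meta.collationName()} " +
      s"pad=${meta.padAttribute()} caseSensitive=${meta.caseSensitivity()}")
  }
}
```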
diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
index 5719303a0dce8..a445cde52ad57 100644
--- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
+++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
@@ -629,6 +629,8 @@ public void testStartsWith() throws SparkException {
assertStartsWith("İonic", "Io", "UTF8_LCASE", false);
assertStartsWith("İonic", "i\u0307o", "UTF8_LCASE", true);
assertStartsWith("İonic", "İo", "UTF8_LCASE", true);
+ assertStartsWith("oİ", "oİ", "UTF8_LCASE", true);
+ assertStartsWith("oİ", "oi̇", "UTF8_LCASE", true);
// Conditional case mapping (e.g. Greek sigmas).
assertStartsWith("σ", "σ", "UTF8_BINARY", true);
assertStartsWith("σ", "ς", "UTF8_BINARY", false);
@@ -880,6 +882,8 @@ public void testEndsWith() throws SparkException {
assertEndsWith("the İo", "Io", "UTF8_LCASE", false);
assertEndsWith("the İo", "i\u0307o", "UTF8_LCASE", true);
assertEndsWith("the İo", "İo", "UTF8_LCASE", true);
+ assertEndsWith("İo", "İo", "UTF8_LCASE", true);
+ assertEndsWith("İo", "i̇o", "UTF8_LCASE", true);
// Conditional case mapping (e.g. Greek sigmas).
assertEndsWith("σ", "σ", "UTF8_BINARY", true);
assertEndsWith("σ", "ς", "UTF8_BINARY", false);
diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index e2725a98a63bd..e83202d9e5ee3 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -1,4 +1,10 @@
{
+ "ADD_DEFAULT_UNSUPPORTED" : {
+ "message" : [
+ "Failed to execute command because DEFAULT values are not supported when adding new columns to previously existing target data source with table provider: \"\"."
+ ],
+ "sqlState" : "42623"
+ },
"AGGREGATE_FUNCTION_WITH_NONDETERMINISTIC_EXPRESSION" : {
"message" : [
"Non-deterministic expression should not appear in the arguments of an aggregate function."
@@ -434,6 +440,12 @@
],
"sqlState" : "42846"
},
+ "CANNOT_USE_KRYO" : {
+ "message" : [
+ "Cannot load Kryo serialization codec. Kryo serialization cannot be used in the Spark Connect client. Use Java serialization, provide a custom Codec, or use Spark Classic instead."
+ ],
+ "sqlState" : "22KD3"
+ },
"CANNOT_WRITE_STATE_STORE" : {
"message" : [
"Error writing state store files for provider ."
@@ -449,13 +461,13 @@
},
"CAST_INVALID_INPUT" : {
"message" : [
- "The value of the type cannot be cast to because it is malformed. Correct the value as per the syntax, or change its target type. Use `try_cast` to tolerate malformed input and return NULL instead. If necessary set to \"false\" to bypass this error."
+ "The value of the type cannot be cast to because it is malformed. Correct the value as per the syntax, or change its target type. Use `try_cast` to tolerate malformed input and return NULL instead."
],
"sqlState" : "22018"
},
"CAST_OVERFLOW" : {
"message" : [
- "The value of the type cannot be cast to due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary set to \"false\" to bypass this error."
+ "The value of the type cannot be cast to due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead."
],
"sqlState" : "22003"
},
@@ -898,7 +910,7 @@
},
"NON_STRING_TYPE" : {
"message" : [
- "all arguments must be strings."
+ "all arguments of the function must be strings."
]
},
"NULL_TYPE" : {
@@ -1039,6 +1051,12 @@
],
"sqlState" : "42710"
},
+ "DATA_SOURCE_EXTERNAL_ERROR" : {
+ "message" : [
+ "Encountered error when saving to external data source."
+ ],
+ "sqlState" : "KD010"
+ },
"DATA_SOURCE_NOT_EXIST" : {
"message" : [
"Data source '' not found. Please make sure the data source is registered."
@@ -1084,6 +1102,12 @@
],
"sqlState" : "42608"
},
+ "DEFAULT_UNSUPPORTED" : {
+ "message" : [
+ "Failed to execute command because DEFAULT values are not supported for target data source with table provider: \"\"."
+ ],
+ "sqlState" : "42623"
+ },
"DISTINCT_WINDOW_FUNCTION_UNSUPPORTED" : {
"message" : [
"Distinct window functions are not supported: ."
@@ -1432,6 +1456,12 @@
],
"sqlState" : "2203G"
},
+ "FAILED_TO_LOAD_ROUTINE" : {
+ "message" : [
+ "Failed to load routine ."
+ ],
+ "sqlState" : "38000"
+ },
"FAILED_TO_PARSE_TOO_COMPLEX" : {
"message" : [
"The statement, including potential SQL functions and referenced views, was too complex to parse.",
@@ -1457,6 +1487,12 @@
],
"sqlState" : "42704"
},
+ "FLATMAPGROUPSWITHSTATE_USER_FUNCTION_ERROR" : {
+ "message" : [
+ "An error occurred in the user provided function in flatMapGroupsWithState. Reason: "
+ ],
+ "sqlState" : "39000"
+ },
"FORBIDDEN_OPERATION" : {
"message" : [
"The operation is not allowed on the : ."
@@ -1469,6 +1505,12 @@
],
"sqlState" : "39000"
},
+ "FOREACH_USER_FUNCTION_ERROR" : {
+ "message" : [
+ "An error occurred in the user provided function in foreach sink. Reason: "
+ ],
+ "sqlState" : "39000"
+ },
"FOUND_MULTIPLE_DATA_SOURCES" : {
"message" : [
"Detected multiple data sources with the name ''. Please check the data source isn't simultaneously registered and located in the classpath."
@@ -1565,6 +1607,36 @@
],
"sqlState" : "42601"
},
+ "IDENTITY_COLUMNS_DUPLICATED_SEQUENCE_GENERATOR_OPTION" : {
+ "message" : [
+ "Duplicated IDENTITY column sequence generator option: ."
+ ],
+ "sqlState" : "42601"
+ },
+ "IDENTITY_COLUMNS_ILLEGAL_STEP" : {
+ "message" : [
+ "IDENTITY column step cannot be 0."
+ ],
+ "sqlState" : "42611"
+ },
+ "IDENTITY_COLUMNS_UNSUPPORTED_DATA_TYPE" : {
+ "message" : [
+ "DataType is not supported for IDENTITY columns."
+ ],
+ "sqlState" : "428H2"
+ },
+ "IDENTITY_COLUMN_WITH_DEFAULT_VALUE" : {
+ "message" : [
+ "A column cannot have both a default value and an identity column specification but column has default value: () and identity column specification: ()."
+ ],
+ "sqlState" : "42623"
+ },
+ "ILLEGAL_DAY_OF_WEEK" : {
+ "message" : [
+ "Illegal input for day of week: ."
+ ],
+ "sqlState" : "22009"
+ },
"ILLEGAL_STATE_STORE_VALUE" : {
"message" : [
"Illegal value provided to the State Store"
@@ -1942,6 +2014,12 @@
},
"sqlState" : "42903"
},
+ "INVALID_AGNOSTIC_ENCODER" : {
+ "message" : [
+ "Found an invalid agnostic encoder. Expects an instance of AgnosticEncoder but got . For more information consult '/api/java/index.html?org/apache/spark/sql/Encoder.html'."
+ ],
+ "sqlState" : "42001"
+ },
"INVALID_ARRAY_INDEX" : {
"message" : [
"The index is out of bounds. The array has elements. Use the SQL function `get()` to tolerate accessing element at invalid index and return NULL instead. If necessary set to \"false\" to bypass this error."
@@ -2074,6 +2152,11 @@
"message" : [
"Too many letters in datetime pattern: . Please reduce pattern length."
]
+ },
+ "SECONDS_FRACTION" : {
+ "message" : [
+ "Cannot detect a seconds fraction pattern of variable length. Please make sure the pattern contains 'S', and does not contain illegal characters."
+ ]
}
},
"sqlState" : "22007"
@@ -2372,6 +2455,11 @@
"Uncaught arithmetic exception while parsing ''."
]
},
+ "DAY_TIME_PARSING" : {
+ "message" : [
+ "Error parsing interval day-time string: ."
+ ]
+ },
"INPUT_IS_EMPTY" : {
"message" : [
"Interval string cannot be empty."
@@ -2382,6 +2470,11 @@
"Interval string cannot be null."
]
},
+ "INTERVAL_PARSING" : {
+ "message" : [
+ "Error parsing interval string."
+ ]
+ },
"INVALID_FRACTION" : {
"message" : [
" cannot have fractional part."
@@ -2417,15 +2510,35 @@
"Expect a unit name after but hit EOL."
]
},
+ "SECOND_NANO_FORMAT" : {
+ "message" : [
+ "Interval string does not match second-nano format of ss.nnnnnnnnn."
+ ]
+ },
"UNKNOWN_PARSING_ERROR" : {
"message" : [
"Unknown error when parsing ."
]
},
+ "UNMATCHED_FORMAT_STRING" : {
+ "message" : [
+ "Interval string does not match format of when cast to : ."
+ ]
+ },
+ "UNMATCHED_FORMAT_STRING_WITH_NOTICE" : {
+ "message" : [
+ "Interval string does not match format of when cast to : . Set \"spark.sql.legacy.fromDayTimeString.enabled\" to \"true\" to restore the behavior before Spark 3.0."
+ ]
+ },
"UNRECOGNIZED_NUMBER" : {
"message" : [
"Unrecognized number ."
]
+ },
+ "UNSUPPORTED_FROM_TO_EXPRESSION" : {
+ "message" : [
+ "Cannot support (interval '' to ) expression."
+ ]
}
},
"sqlState" : "22006"
@@ -2489,6 +2602,24 @@
],
"sqlState" : "F0000"
},
+ "INVALID_LABEL_USAGE" : {
+ "message" : [
+ "The usage of the label is invalid."
+ ],
+ "subClass" : {
+ "DOES_NOT_EXIST" : {
+ "message" : [
+ "Label was used in the statement, but the label does not belong to any surrounding block."
+ ]
+ },
+ "ITERATE_IN_COMPOUND" : {
+ "message" : [
+ "ITERATE statement cannot be used with a label that belongs to a compound (BEGIN...END) body."
+ ]
+ }
+ },
+ "sqlState" : "42K0L"
+ },
"INVALID_LAMBDA_FUNCTION_CALL" : {
"message" : [
"Invalid lambda function call."
@@ -3041,12 +3172,12 @@
"subClass" : {
"NOT_ALLOWED_IN_SCOPE" : {
"message" : [
- "Variable was declared on line , which is not allowed in this scope."
+ "Declaration of the variable is not allowed in this scope."
]
},
"ONLY_AT_BEGINNING" : {
"message" : [
- "Variable can only be declared at the beginning of the compound, but it was declared on line ."
+ "Variable can only be declared at the beginning of the compound."
]
}
},
@@ -3671,6 +3802,12 @@
],
"sqlState" : "42K03"
},
+ "PIPE_OPERATOR_SELECT_CONTAINS_AGGREGATE_FUNCTION" : {
+ "message" : [
+ "Aggregate function is not allowed when using the pipe operator |> SELECT clause; please use the pipe operator |> AGGREGATE clause instead"
+ ],
+ "sqlState" : "0A000"
+ },
"PIVOT_VALUE_DATA_TYPE_MISMATCH" : {
"message" : [
"Invalid pivot value '': value data type does not match pivot column data type ."
@@ -4326,6 +4463,24 @@
],
"sqlState" : "428EK"
},
+ "TRANSPOSE_EXCEED_ROW_LIMIT" : {
+ "message" : [
+ "Number of rows exceeds the allowed limit of for TRANSPOSE. If this was intended, set to at least the current row count."
+ ],
+ "sqlState" : "54006"
+ },
+ "TRANSPOSE_INVALID_INDEX_COLUMN" : {
+ "message" : [
+ "Invalid index column for TRANSPOSE because: "
+ ],
+ "sqlState" : "42804"
+ },
+ "TRANSPOSE_NO_LEAST_COMMON_TYPE" : {
+ "message" : [
+ "Transpose requires non-index columns to share a least common type, but and do not."
+ ],
+ "sqlState" : "42K09"
+ },
"UDTF_ALIAS_NUMBER_MISMATCH" : {
"message" : [
"The number of aliases supplied in the AS clause does not match the number of columns output by the UDTF.",
@@ -5199,6 +5354,11 @@
""
]
},
+ "SCALAR_SUBQUERY_IN_VALUES" : {
+ "message" : [
+ "Scalar subqueries in the VALUES clause."
+ ]
+ },
"UNSUPPORTED_CORRELATED_EXPRESSION_IN_JOIN_CONDITION" : {
"message" : [
"Correlated subqueries in the join predicate cannot reference both join inputs:",
@@ -6168,7 +6328,7 @@
"Detected implicit cartesian product for join between logical plans",
"",
"and",
- "rightPlan",
+ "",
"Join condition is missing or trivial.",
"Either: use the CROSS JOIN syntax to allow cartesian products between these relations, or: enable implicit cartesian products by setting the configuration variable spark.sql.crossJoin.enabled=true."
]
@@ -6531,21 +6691,6 @@
"Sinks cannot request distribution and ordering in continuous execution mode."
]
},
- "_LEGACY_ERROR_TEMP_1344" : {
- "message" : [
- "Invalid DEFAULT value for column : fails to parse as a valid literal value."
- ]
- },
- "_LEGACY_ERROR_TEMP_1345" : {
- "message" : [
- "Failed to execute command because DEFAULT values are not supported for target data source with table provider: \"\"."
- ]
- },
- "_LEGACY_ERROR_TEMP_1346" : {
- "message" : [
- "Failed to execute command because DEFAULT values are not supported when adding new columns to previously existing target data source with table provider: \"\"."
- ]
- },
"_LEGACY_ERROR_TEMP_2000" : {
"message" : [
". If necessary set to false to bypass this error."
@@ -6561,11 +6706,6 @@
"Type does not support ordered operations."
]
},
- "_LEGACY_ERROR_TEMP_2011" : {
- "message" : [
- "Unexpected data type ."
- ]
- },
"_LEGACY_ERROR_TEMP_2013" : {
"message" : [
"Negative values found in "
@@ -7773,7 +7913,7 @@
},
"_LEGACY_ERROR_TEMP_3055" : {
"message" : [
- "ScalarFunction '' neither implement magic method nor override 'produceResult'"
+ "ScalarFunction neither implement magic method nor override 'produceResult'"
]
},
"_LEGACY_ERROR_TEMP_3056" : {
@@ -8379,36 +8519,6 @@
"The number of fields () in the partition identifier is not equal to the partition schema length (). The identifier might not refer to one partition."
]
},
- "_LEGACY_ERROR_TEMP_3209" : {
- "message" : [
- "Illegal input for day of week: "
- ]
- },
- "_LEGACY_ERROR_TEMP_3210" : {
- "message" : [
- "Interval string does not match second-nano format of ss.nnnnnnnnn"
- ]
- },
- "_LEGACY_ERROR_TEMP_3211" : {
- "message" : [
- "Error parsing interval day-time string: "
- ]
- },
- "_LEGACY_ERROR_TEMP_3212" : {
- "message" : [
- "Cannot support (interval '' to ) expression"
- ]
- },
- "_LEGACY_ERROR_TEMP_3213" : {
- "message" : [
- "Error parsing interval string: "
- ]
- },
- "_LEGACY_ERROR_TEMP_3214" : {
- "message" : [
- "Interval string does not match format of when cast to : "
- ]
- },
"_LEGACY_ERROR_TEMP_3215" : {
"message" : [
"Expected a Boolean type expression in replaceNullWithFalse, but got the type in ."
diff --git a/common/utils/src/main/resources/error/error-states.json b/common/utils/src/main/resources/error/error-states.json
index c369db3f65058..87811fef9836e 100644
--- a/common/utils/src/main/resources/error/error-states.json
+++ b/common/utils/src/main/resources/error/error-states.json
@@ -7417,6 +7417,12 @@
"standard": "N",
"usedBy": ["Databricks"]
},
+ "KD010": {
+ "description": "external data source failure",
+ "origin": "Databricks",
+ "standard": "N",
+ "usedBy": ["Databricks"]
+ },
"P0000": {
"description": "procedural logic error",
"origin": "PostgreSQL",
diff --git a/common/utils/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala b/common/utils/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala
index a1934dcf7a007..e2dd0da1aac85 100644
--- a/common/utils/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala
+++ b/common/utils/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala
@@ -19,7 +19,6 @@ package org.apache.spark
import java.net.URL
-import scala.collection.immutable.Map
import scala.jdk.CollectionConverters._
import com.fasterxml.jackson.annotation.JsonIgnore
@@ -52,7 +51,7 @@ class ErrorClassesJsonReader(jsonFileURLs: Seq[URL]) {
val sub = new StringSubstitutor(sanitizedParameters.asJava)
sub.setEnableUndefinedVariableException(true)
sub.setDisableSubstitutionInValues(true)
- try {
+ val errorMessage = try {
sub.replace(ErrorClassesJsonReader.TEMPLATE_REGEX.replaceAllIn(
messageTemplate, "\\$\\{$1\\}"))
} catch {
@@ -61,6 +60,17 @@ class ErrorClassesJsonReader(jsonFileURLs: Seq[URL]) {
s"MessageTemplate: $messageTemplate, " +
s"Parameters: $messageParameters", i)
}
+ if (util.SparkEnvUtils.isTesting) {
+ val placeHoldersNum = ErrorClassesJsonReader.TEMPLATE_REGEX.findAllIn(messageTemplate).length
+ if (placeHoldersNum < sanitizedParameters.size) {
+ throw SparkException.internalError(
+ s"Found unused message parameters of the error class '$errorClass'. " +
+ s"Its error message format has $placeHoldersNum placeholders, " +
+ s"but the passed message parameters map has ${sanitizedParameters.size} items. " +
+ "Consider to add placeholders to the error format or remove unused message parameters.")
+ }
+ }
+ errorMessage
}
def getMessageParameters(errorClass: String): Seq[String] = {
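
The testing-only guard added above counts the placeholder slots in a message template and fails when more message parameters are supplied than there are placeholders. Below is a standalone sketch of that counting logic under assumed names; Spark's real `TEMPLATE_REGEX` and `SparkException.internalError` are not reproduced here.

```scala
// Hypothetical, self-contained version of the unused-parameter check.
object PlaceholderCheckSketch {
  private val TemplateRegex = "<([a-zA-Z0-9_-]+)>".r

  def assertNoUnusedParameters(template: String, parameters: Map[String, String]): Unit = {
    val placeHoldersNum = TemplateRegex.findAllIn(template).length
    if (placeHoldersNum < parameters.size) {
      throw new IllegalStateException(
        s"Message template has $placeHoldersNum placeholders but " +
          s"${parameters.size} parameters were supplied.")
    }
  }

  def main(args: Array[String]): Unit = {
    // Passes: one placeholder, one parameter.
    assertNoUnusedParameters("Unknown column <name>.", Map("name" -> "c1"))
    // Throws: one placeholder, two parameters (the second one is unused).
    assertNoUnusedParameters("Unknown column <name>.", Map("name" -> "c1", "extra" -> "x"))
  }
}
```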
diff --git a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala
index a7e4f186000b5..12d456a371d07 100644
--- a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala
+++ b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala
@@ -266,6 +266,7 @@ private[spark] object LogKeys {
case object FEATURE_NAME extends LogKey
case object FETCH_SIZE extends LogKey
case object FIELD_NAME extends LogKey
+ case object FIELD_TYPE extends LogKey
case object FILES extends LogKey
case object FILE_ABSOLUTE_PATH extends LogKey
case object FILE_END_OFFSET extends LogKey
@@ -652,6 +653,7 @@ private[spark] object LogKeys {
case object RECEIVER_IDS extends LogKey
case object RECORDS extends LogKey
case object RECOVERY_STATE extends LogKey
+ case object RECURSIVE_DEPTH extends LogKey
case object REDACTED_STATEMENT extends LogKey
case object REDUCE_ID extends LogKey
case object REGEX extends LogKey
diff --git a/common/utils/src/main/scala/org/apache/spark/scheduler/SparkListenerEvent.scala b/common/utils/src/main/scala/org/apache/spark/scheduler/SparkListenerEvent.scala
new file mode 100644
index 0000000000000..1f1d3492d6ac5
--- /dev/null
+++ b/common/utils/src/main/scala/org/apache/spark/scheduler/SparkListenerEvent.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.scheduler
+
+import com.fasterxml.jackson.annotation.JsonTypeInfo
+
+import org.apache.spark.annotation.DeveloperApi
+
+@DeveloperApi
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "Event")
+trait SparkListenerEvent {
+ /* Whether output this event to the event log */
+ protected[spark] def logEvent: Boolean = true
+}
diff --git a/common/utils/src/main/scala/org/apache/spark/util/MavenUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/MavenUtils.scala
index 42a1d1612aeeb..d54a2f2ed9cea 100644
--- a/common/utils/src/main/scala/org/apache/spark/util/MavenUtils.scala
+++ b/common/utils/src/main/scala/org/apache/spark/util/MavenUtils.scala
@@ -342,7 +342,7 @@ private[spark] object MavenUtils extends Logging {
}
/* Set ivy settings for location of cache, if option is supplied */
- private def processIvyPathArg(ivySettings: IvySettings, ivyPath: Option[String]): Unit = {
+ private[util] def processIvyPathArg(ivySettings: IvySettings, ivyPath: Option[String]): Unit = {
val alternateIvyDir = ivyPath.filterNot(_.trim.isEmpty).getOrElse {
// To protect old Ivy-based systems like old Spark from Apache Ivy 2.5.2's incompatibility.
System.getProperty("ivy.home",
diff --git a/common/utils/src/test/scala/org/apache/spark/util/IvyTestUtils.scala b/common/utils/src/test/scala/org/apache/spark/util/IvyTestUtils.scala
index 76062074edcaf..140de836622f4 100644
--- a/common/utils/src/test/scala/org/apache/spark/util/IvyTestUtils.scala
+++ b/common/utils/src/test/scala/org/apache/spark/util/IvyTestUtils.scala
@@ -365,7 +365,7 @@ private[spark] object IvyTestUtils {
useIvyLayout: Boolean = false,
withPython: Boolean = false,
withR: Boolean = false,
- ivySettings: IvySettings = new IvySettings)(f: String => Unit): Unit = {
+ ivySettings: IvySettings = defaultIvySettings())(f: String => Unit): Unit = {
val deps = dependencies.map(MavenUtils.extractMavenCoordinates)
purgeLocalIvyCache(artifact, deps, ivySettings)
val repo = createLocalRepositoryForTests(artifact, dependencies, rootDir, useIvyLayout,
@@ -401,4 +401,16 @@ private[spark] object IvyTestUtils {
}
}
}
+
+ /**
+ * Creates and initializes a new instance of IvySettings with default configurations.
+ * The method processes the Ivy path argument using MavenUtils to ensure proper setup.
+ *
+ * @return A newly created and configured instance of IvySettings.
+ */
+ private def defaultIvySettings(): IvySettings = {
+ val settings = new IvySettings
+ MavenUtils.processIvyPathArg(ivySettings = settings, ivyPath = None)
+ settings
+ }
}
diff --git a/common/variant/README.md b/common/variant/README.md
index a66d708da75bf..4ed7c16f5b6ed 100644
--- a/common/variant/README.md
+++ b/common/variant/README.md
@@ -333,27 +333,27 @@ The Decimal type contains a scale, but no precision. The implied precision of a
| Object | `2` | A collection of (string-key, variant-value) pairs |
| Array | `3` | An ordered sequence of variant values |
-| Primitive Type | Type ID | Equivalent Parquet Type | Binary format |
-|-----------------------------|---------|-----------------------------|---------------------------------------------------------------------------------------------------------------------|
-| null | `0` | any | none |
-| boolean (True) | `1` | BOOLEAN | none |
-| boolean (False) | `2` | BOOLEAN | none |
-| int8 | `3` | INT(8, signed) | 1 byte |
-| int16 | `4` | INT(16, signed) | 2 byte little-endian |
-| int32 | `5` | INT(32, signed) | 4 byte little-endian |
-| int64 | `6` | INT(64, signed) | 8 byte little-endian |
-| double | `7` | DOUBLE | IEEE little-endian |
-| decimal4 | `8` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) |
-| decimal8 | `9` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) |
-| decimal16 | `10` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) |
-| date | `11` | DATE | 4 byte little-endian |
-| timestamp | `12` | TIMESTAMP(true, MICROS) | 8-byte little-endian |
-| timestamp without time zone | `13` | TIMESTAMP(false, MICROS) | 8-byte little-endian |
-| float | `14` | FLOAT | IEEE little-endian |
-| binary | `15` | BINARY | 4 byte little-endian size, followed by bytes |
-| string | `16` | STRING | 4 byte little-endian size, followed by UTF-8 encoded bytes |
-| year-month interval | `19` | INT(32, signed)1 | 1 byte denoting start field (1 bit) and end field (1 bit) starting at LSB followed by 4-byte little-endian value. |
-| day-time interval | `20` | INT(64, signed)1 | 1 byte denoting start field (2 bits) and end field (2 bits) starting at LSB followed by 8-byte little-endian value. |
+| Logical Type | Physical Type | Type ID | Equivalent Parquet Type | Binary format |
+|----------------------|-----------------------------|---------|-----------------------------|---------------------------------------------------------------------------------------------------------------------|
+| NullType | null | `0` | any | none |
+| Boolean | boolean (True) | `1` | BOOLEAN | none |
+| Boolean | boolean (False) | `2` | BOOLEAN | none |
+| Exact Numeric | int8 | `3` | INT(8, signed) | 1 byte |
+| Exact Numeric | int16 | `4` | INT(16, signed) | 2 byte little-endian |
+| Exact Numeric | int32 | `5` | INT(32, signed) | 4 byte little-endian |
+| Exact Numeric | int64 | `6` | INT(64, signed) | 8 byte little-endian |
+| Double | double | `7` | DOUBLE | IEEE little-endian |
+| Exact Numeric | decimal4 | `8` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) |
+| Exact Numeric | decimal8 | `9` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) |
+| Exact Numeric | decimal16 | `10` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) |
+| Date | date | `11` | DATE | 4 byte little-endian |
+| Timestamp | timestamp | `12` | TIMESTAMP(true, MICROS) | 8-byte little-endian |
+| TimestampNTZ | timestamp without time zone | `13` | TIMESTAMP(false, MICROS) | 8-byte little-endian |
+| Float | float | `14` | FLOAT | IEEE little-endian |
+| Binary | binary | `15` | BINARY | 4 byte little-endian size, followed by bytes |
+| String | string | `16` | STRING | 4 byte little-endian size, followed by UTF-8 encoded bytes |
+| YMInterval | year-month interval | `19` | INT(32, signed)1 | 1 byte denoting start field (1 bit) and end field (1 bit) starting at LSB followed by 4-byte little-endian value. |
+| DTInterval | day-time interval | `20` | INT(64, signed)1 | 1 byte denoting start field (2 bits) and end field (2 bits) starting at LSB followed by 8-byte little-endian value. |
| Decimal Precision | Decimal value type |
|-----------------------|--------------------|
@@ -362,6 +362,8 @@ The Decimal type contains a scale, but no precision. The implied precision of a
| 18 <= precision <= 38 | int128 |
| > 38 | Not supported |
+The *Logical Type* column indicates logical equivalence of physically encoded types. For example, a user expression operating on a string value containing "hello" should behave the same, whether it is encoded with the short string optimization, or long string encoding. Similarly, user expressions operating on an *int8* value of 1 should behave the same as a decimal16 with scale 2 and unscaled value 100.
+
The year-month and day-time interval types have one byte at the beginning indicating the start and end fields. In the case of the year-month interval, the least significant bit denotes the start field and the next least significant bit denotes the end field. The remaining 6 bits are unused. A field value of 0 represents YEAR and 1 represents MONTH. In the case of the day-time interval, the least significant 2 bits denote the start field and the next least significant 2 bits denote the end field. The remaining 4 bits are unused. A field value of 0 represents DAY, 1 represents HOUR, 2 represents MINUTE, and 3 represents SECOND.
Type IDs 17 and 18 were originally reserved for a prototype feature (string-from-metadata) that was never implemented. These IDs are available for use by new types.
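
As a concrete illustration of the logical-type equivalence described above (not part of the spec text): an *int8* encoding of 1 and a *decimal16* with scale 2 and unscaled value 100 denote the same Exact Numeric value, much like the following scale-insensitive comparison.

```scala
import java.math.BigDecimal

object ExactNumericEquivalence {
  def main(args: Array[String]): Unit = {
    val fromInt8 = BigDecimal.valueOf(1L)           // int8 encoding of the value 1
    val fromDecimal16 = BigDecimal.valueOf(100L, 2) // unscaled 100, scale 2 => 1.00
    // compareTo ignores scale, mirroring the logical-type comparison described above.
    println(fromInt8.compareTo(fromDecimal16) == 0) // true
  }
}
```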
diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala
index 7d80998d96eb1..0b85b208242cb 100644
--- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala
@@ -42,7 +42,8 @@ private[sql] case class AvroDataToCatalyst(
val dt = SchemaConverters.toSqlType(
expectedSchema,
avroOptions.useStableIdForUnionType,
- avroOptions.stableIdPrefixForUnionType).dataType
+ avroOptions.stableIdPrefixForUnionType,
+ avroOptions.recursiveFieldMaxDepth).dataType
parseMode match {
// With PermissiveMode, the output Catalyst row might contain columns of null values for
// corrupt records, even if some of the columns are not nullable in the user-provided schema.
@@ -69,7 +70,8 @@ private[sql] case class AvroDataToCatalyst(
dataType,
avroOptions.datetimeRebaseModeInRead,
avroOptions.useStableIdForUnionType,
- avroOptions.stableIdPrefixForUnionType)
+ avroOptions.stableIdPrefixForUnionType,
+ avroOptions.recursiveFieldMaxDepth)
@transient private var decoder: BinaryDecoder = _
diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
index 877c3f89e88c0..ac20614553ca2 100644
--- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
@@ -51,14 +51,16 @@ private[sql] class AvroDeserializer(
datetimeRebaseSpec: RebaseSpec,
filters: StructFilters,
useStableIdForUnionType: Boolean,
- stableIdPrefixForUnionType: String) {
+ stableIdPrefixForUnionType: String,
+ recursiveFieldMaxDepth: Int) {
def this(
rootAvroType: Schema,
rootCatalystType: DataType,
datetimeRebaseMode: String,
useStableIdForUnionType: Boolean,
- stableIdPrefixForUnionType: String) = {
+ stableIdPrefixForUnionType: String,
+ recursiveFieldMaxDepth: Int) = {
this(
rootAvroType,
rootCatalystType,
@@ -66,7 +68,8 @@ private[sql] class AvroDeserializer(
RebaseSpec(LegacyBehaviorPolicy.withName(datetimeRebaseMode)),
new NoopFilters,
useStableIdForUnionType,
- stableIdPrefixForUnionType)
+ stableIdPrefixForUnionType,
+ recursiveFieldMaxDepth)
}
private lazy val decimalConversions = new DecimalConversion()
@@ -128,7 +131,8 @@ private[sql] class AvroDeserializer(
s"schema is incompatible (avroType = $avroType, sqlType = ${catalystType.sql})"
val realDataType = SchemaConverters.toSqlType(
- avroType, useStableIdForUnionType, stableIdPrefixForUnionType).dataType
+ avroType, useStableIdForUnionType, stableIdPrefixForUnionType,
+ recursiveFieldMaxDepth).dataType
(avroType.getType, catalystType) match {
case (NULL, NullType) => (updater, ordinal, _) =>
diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
index 372f24b54f5c4..264c3a1f48abe 100755
--- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
@@ -145,7 +145,8 @@ private[sql] class AvroFileFormat extends FileFormat
datetimeRebaseMode,
avroFilters,
parsedOptions.useStableIdForUnionType,
- parsedOptions.stableIdPrefixForUnionType)
+ parsedOptions.stableIdPrefixForUnionType,
+ parsedOptions.recursiveFieldMaxDepth)
override val stopPosition = file.start + file.length
override def hasNext: Boolean = hasNextRow
diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala
index 4332904339f19..e0c6ad3ee69d3 100644
--- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala
@@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.{DataSourceOptions, FileSourceOptions}
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, FailFastMode, ParseMode}
+import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
/**
@@ -136,6 +137,15 @@ private[sql] class AvroOptions(
val stableIdPrefixForUnionType: String = parameters
.getOrElse(STABLE_ID_PREFIX_FOR_UNION_TYPE, "member_")
+
+ val recursiveFieldMaxDepth: Int =
+ parameters.get(RECURSIVE_FIELD_MAX_DEPTH).map(_.toInt).getOrElse(-1)
+
+ if (recursiveFieldMaxDepth > RECURSIVE_FIELD_MAX_DEPTH_LIMIT) {
+ throw QueryCompilationErrors.avroOptionsException(
+ RECURSIVE_FIELD_MAX_DEPTH,
+ s"Should not be greater than $RECURSIVE_FIELD_MAX_DEPTH_LIMIT.")
+ }
}
private[sql] object AvroOptions extends DataSourceOptions {
@@ -170,4 +180,25 @@ private[sql] object AvroOptions extends DataSourceOptions {
// When STABLE_ID_FOR_UNION_TYPE is enabled, the option allows to configure the prefix for fields
// of Avro Union type.
val STABLE_ID_PREFIX_FOR_UNION_TYPE = newOption("stableIdentifierPrefixForUnionType")
+
+ /**
+ * Adds support for recursive fields. If this option is not specified or is set to 0, recursive
+ * fields are not permitted. Setting it to 1 drops all recursive fields, 2 allows recursive
+ * fields to be recursed once, and 3 allows them to be recursed twice and so on, up to 15.
+ * Values larger than 15 are not allowed in order to avoid inadvertently creating very large
+ * schemas. If an avro message has depth beyond this limit, the Spark struct returned is
+ * truncated after the recursion limit.
+ *
+ * Examples: Consider an Avro schema with a recursive field:
+ * {"type" : "record", "name" : "Node", "fields" : [{"name": "Id", "type": "int"},
+ * {"name": "Next", "type": ["null", "Node"]}]}
+ * The following lists the parsed schema with different values for this setting.
+ * 1: `struct<Id: int>`
+ * 2: `struct<Id: int, Next: struct<Id: int>>`
+ * 3: `struct<Id: int, Next: struct<Id: int, Next: struct<Id: int>>>`
+ * and so on.
+ */
+ val RECURSIVE_FIELD_MAX_DEPTH = newOption("recursiveFieldMaxDepth")
+
+ val RECURSIVE_FIELD_MAX_DEPTH_LIMIT: Int = 15
}
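
A hedged usage sketch of the new `recursiveFieldMaxDepth` read option added above. The option name comes from this patch; the SparkSession setup and the input path are illustrative only, and the spark-avro package is assumed to be on the classpath.

```scala
import org.apache.spark.sql.SparkSession

object RecursiveAvroReadSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("recursive-avro").getOrCreate()

    // Allow recursive Avro records (e.g. a linked-list style "Node" schema) to be
    // unrolled up to two levels; deeper occurrences are truncated as documented above.
    val df = spark.read
      .format("avro")
      .option("recursiveFieldMaxDepth", "2")
      .load("/tmp/nodes.avro") // hypothetical path

    df.printSchema()
    spark.stop()
  }
}
```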
diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala
index 7cbc30f1fb3dc..594ebb4716c41 100644
--- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala
@@ -65,7 +65,8 @@ private[sql] object AvroUtils extends Logging {
SchemaConverters.toSqlType(
avroSchema,
parsedOptions.useStableIdForUnionType,
- parsedOptions.stableIdPrefixForUnionType).dataType match {
+ parsedOptions.stableIdPrefixForUnionType,
+ parsedOptions.recursiveFieldMaxDepth).dataType match {
case t: StructType => Some(t)
case _ => throw new RuntimeException(
s"""Avro schema cannot be converted to a Spark SQL StructType:
diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
index b2285aa966ddb..1168a887abd8e 100644
--- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
@@ -27,6 +27,10 @@ import org.apache.avro.LogicalTypes.{Date, Decimal, LocalTimestampMicros, LocalT
import org.apache.avro.Schema.Type._
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.LogKeys.{FIELD_NAME, FIELD_TYPE, RECURSIVE_DEPTH}
+import org.apache.spark.internal.MDC
+import org.apache.spark.sql.avro.AvroOptions.RECURSIVE_FIELD_MAX_DEPTH_LIMIT
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.Decimal.minBytesForPrecision
@@ -36,7 +40,7 @@ import org.apache.spark.sql.types.Decimal.minBytesForPrecision
* versa.
*/
@DeveloperApi
-object SchemaConverters {
+object SchemaConverters extends Logging {
private lazy val nullSchema = Schema.create(Schema.Type.NULL)
/**
@@ -48,14 +52,27 @@ object SchemaConverters {
/**
* Converts an Avro schema to a corresponding Spark SQL schema.
- *
+ *
+ * @param avroSchema The Avro schema to convert.
+ * @param useStableIdForUnionType If true, Avro schema is deserialized into Spark SQL schema,
+ * and the Avro Union type is transformed into a structure where
+ * the field names remain consistent with their respective types.
+ * @param stableIdPrefixForUnionType The prefix to use for fields of
+ * Avro Union type
+ * @param recursiveFieldMaxDepth The maximum depth to recursively process fields in Avro schema.
+ * -1 means not supported.
* @since 4.0.0
*/
def toSqlType(
avroSchema: Schema,
useStableIdForUnionType: Boolean,
- stableIdPrefixForUnionType: String): SchemaType = {
- toSqlTypeHelper(avroSchema, Set.empty, useStableIdForUnionType, stableIdPrefixForUnionType)
+ stableIdPrefixForUnionType: String,
+ recursiveFieldMaxDepth: Int = -1): SchemaType = {
+ val schema = toSqlTypeHelper(avroSchema, Map.empty, useStableIdForUnionType,
+ stableIdPrefixForUnionType, recursiveFieldMaxDepth)
+ // the top level record should never return null
+ assert(schema != null)
+ schema
}
/**
* Converts an Avro schema to a corresponding Spark SQL schema.
@@ -63,17 +80,17 @@ object SchemaConverters {
* @since 2.4.0
*/
def toSqlType(avroSchema: Schema): SchemaType = {
- toSqlType(avroSchema, false, "")
+ toSqlType(avroSchema, false, "", -1)
}
@deprecated("using toSqlType(..., useStableIdForUnionType: Boolean) instead", "4.0.0")
def toSqlType(avroSchema: Schema, options: Map[String, String]): SchemaType = {
val avroOptions = AvroOptions(options)
- toSqlTypeHelper(
+ toSqlType(
avroSchema,
- Set.empty,
avroOptions.useStableIdForUnionType,
- avroOptions.stableIdPrefixForUnionType)
+ avroOptions.stableIdPrefixForUnionType,
+ avroOptions.recursiveFieldMaxDepth)
}
// The property specifies Catalyst type of the given field
@@ -81,9 +98,10 @@ object SchemaConverters {
private def toSqlTypeHelper(
avroSchema: Schema,
- existingRecordNames: Set[String],
+ existingRecordNames: Map[String, Int],
useStableIdForUnionType: Boolean,
- stableIdPrefixForUnionType: String): SchemaType = {
+ stableIdPrefixForUnionType: String,
+ recursiveFieldMaxDepth: Int): SchemaType = {
avroSchema.getType match {
case INT => avroSchema.getLogicalType match {
case _: Date => SchemaType(DateType, nullable = false)
@@ -128,62 +146,110 @@ object SchemaConverters {
case NULL => SchemaType(NullType, nullable = true)
case RECORD =>
- if (existingRecordNames.contains(avroSchema.getFullName)) {
+ val recursiveDepth: Int = existingRecordNames.getOrElse(avroSchema.getFullName, 0)
+ if (recursiveDepth > 0 && recursiveFieldMaxDepth <= 0) {
throw new IncompatibleSchemaException(s"""
- |Found recursive reference in Avro schema, which can not be processed by Spark:
- |${avroSchema.toString(true)}
+ |Found recursive reference in Avro schema, which can not be processed by Spark by
+ | default: ${avroSchema.toString(true)}. Try setting the option `recursiveFieldMaxDepth`
+ | to 1 - $RECURSIVE_FIELD_MAX_DEPTH_LIMIT.
""".stripMargin)
- }
- val newRecordNames = existingRecordNames + avroSchema.getFullName
- val fields = avroSchema.getFields.asScala.map { f =>
- val schemaType = toSqlTypeHelper(
- f.schema(),
- newRecordNames,
- useStableIdForUnionType,
- stableIdPrefixForUnionType)
- StructField(f.name, schemaType.dataType, schemaType.nullable)
- }
+ } else if (recursiveDepth > 0 && recursiveDepth >= recursiveFieldMaxDepth) {
+ logInfo(
+ log"The field ${MDC(FIELD_NAME, avroSchema.getFullName)} of type " +
+ log"${MDC(FIELD_TYPE, avroSchema.getType.getName)} is dropped at recursive depth " +
+ log"${MDC(RECURSIVE_DEPTH, recursiveDepth)}."
+ )
+ null
+ } else {
+ val newRecordNames =
+ existingRecordNames + (avroSchema.getFullName -> (recursiveDepth + 1))
+ val fields = avroSchema.getFields.asScala.map { f =>
+ val schemaType = toSqlTypeHelper(
+ f.schema(),
+ newRecordNames,
+ useStableIdForUnionType,
+ stableIdPrefixForUnionType,
+ recursiveFieldMaxDepth)
+ if (schemaType == null) {
+ null
+ }
+ else {
+ StructField(f.name, schemaType.dataType, schemaType.nullable)
+ }
+ }.filter(_ != null).toSeq
- SchemaType(StructType(fields.toArray), nullable = false)
+ SchemaType(StructType(fields), nullable = false)
+ }
case ARRAY =>
val schemaType = toSqlTypeHelper(
avroSchema.getElementType,
existingRecordNames,
useStableIdForUnionType,
- stableIdPrefixForUnionType)
- SchemaType(
- ArrayType(schemaType.dataType, containsNull = schemaType.nullable),
- nullable = false)
+ stableIdPrefixForUnionType,
+ recursiveFieldMaxDepth)
+ if (schemaType == null) {
+ logInfo(
+ log"Dropping ${MDC(FIELD_NAME, avroSchema.getFullName)} of type " +
+ log"${MDC(FIELD_TYPE, avroSchema.getType.getName)} as it does not have any " +
+ log"fields left likely due to recursive depth limit."
+ )
+ null
+ } else {
+ SchemaType(
+ ArrayType(schemaType.dataType, containsNull = schemaType.nullable),
+ nullable = false)
+ }
case MAP =>
val schemaType = toSqlTypeHelper(avroSchema.getValueType,
- existingRecordNames, useStableIdForUnionType, stableIdPrefixForUnionType)
- SchemaType(
- MapType(StringType, schemaType.dataType, valueContainsNull = schemaType.nullable),
- nullable = false)
+ existingRecordNames, useStableIdForUnionType, stableIdPrefixForUnionType,
+ recursiveFieldMaxDepth)
+ if (schemaType == null) {
+ logInfo(
+ log"Dropping ${MDC(FIELD_NAME, avroSchema.getFullName)} of type " +
+ log"${MDC(FIELD_TYPE, avroSchema.getType.getName)} as it does not have any " +
+ log"fields left likely due to recursive depth limit."
+ )
+ null
+ } else {
+ SchemaType(
+ MapType(StringType, schemaType.dataType, valueContainsNull = schemaType.nullable),
+ nullable = false)
+ }
case UNION =>
if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) {
// In case of a union with null, eliminate it and make a recursive call
val remainingUnionTypes = AvroUtils.nonNullUnionBranches(avroSchema)
- if (remainingUnionTypes.size == 1) {
- toSqlTypeHelper(
- remainingUnionTypes.head,
- existingRecordNames,
- useStableIdForUnionType,
- stableIdPrefixForUnionType).copy(nullable = true)
+ val remainingSchema =
+ if (remainingUnionTypes.size == 1) {
+ remainingUnionTypes.head
+ } else {
+ Schema.createUnion(remainingUnionTypes.asJava)
+ }
+ val schemaType = toSqlTypeHelper(
+ remainingSchema,
+ existingRecordNames,
+ useStableIdForUnionType,
+ stableIdPrefixForUnionType,
+ recursiveFieldMaxDepth)
+
+ if (schemaType == null) {
+ logInfo(
+ log"Dropping ${MDC(FIELD_NAME, avroSchema.getFullName)} of type " +
+ log"${MDC(FIELD_TYPE, avroSchema.getType.getName)} as it does not have any " +
+ log"fields left likely due to recursive depth limit."
+ )
+ null
} else {
- toSqlTypeHelper(
- Schema.createUnion(remainingUnionTypes.asJava),
- existingRecordNames,
- useStableIdForUnionType,
- stableIdPrefixForUnionType).copy(nullable = true)
+ schemaType.copy(nullable = true)
}
} else avroSchema.getTypes.asScala.map(_.getType).toSeq match {
case Seq(t1) =>
toSqlTypeHelper(avroSchema.getTypes.get(0),
- existingRecordNames, useStableIdForUnionType, stableIdPrefixForUnionType)
+ existingRecordNames, useStableIdForUnionType, stableIdPrefixForUnionType,
+ recursiveFieldMaxDepth)
case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) =>
SchemaType(LongType, nullable = false)
case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) =>
@@ -201,29 +267,33 @@ object SchemaConverters {
s,
existingRecordNames,
useStableIdForUnionType,
- stableIdPrefixForUnionType)
-
- val fieldName = if (useStableIdForUnionType) {
- // Avro's field name may be case sensitive, so field names for two named type
- // could be "a" and "A" and we need to distinguish them. In this case, we throw
- // an exception.
- // Stable id prefix can be empty so the name of the field can be just the type.
- val tempFieldName = s"${stableIdPrefixForUnionType}${s.getName}"
- if (!fieldNameSet.add(tempFieldName.toLowerCase(Locale.ROOT))) {
- throw new IncompatibleSchemaException(
- "Cannot generate stable identifier for Avro union type due to name " +
- s"conflict of type name ${s.getName}")
- }
- tempFieldName
+ stableIdPrefixForUnionType,
+ recursiveFieldMaxDepth)
+ if (schemaType == null) {
+ null
} else {
- s"member$i"
- }
+ val fieldName = if (useStableIdForUnionType) {
+ // Avro's field name may be case sensitive, so field names for two named type
+ // could be "a" and "A" and we need to distinguish them. In this case, we throw
+ // an exception.
+ // Stable id prefix can be empty so the name of the field can be just the type.
+ val tempFieldName = s"${stableIdPrefixForUnionType}${s.getName}"
+ if (!fieldNameSet.add(tempFieldName.toLowerCase(Locale.ROOT))) {
+ throw new IncompatibleSchemaException(
+ "Cannot generate stable identifier for Avro union type due to name " +
+ s"conflict of type name ${s.getName}")
+ }
+ tempFieldName
+ } else {
+ s"member$i"
+ }
- // All fields are nullable because only one of them is set at a time
- StructField(fieldName, schemaType.dataType, nullable = true)
- }
+ // All fields are nullable because only one of them is set at a time
+ StructField(fieldName, schemaType.dataType, nullable = true)
+ }
+ }.filter(_ != null).toSeq
- SchemaType(StructType(fields.toArray), nullable = false)
+ SchemaType(StructType(fields), nullable = false)
}
case other => throw new IncompatibleSchemaException(s"Unsupported type $other")
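
Illustrative sketch (not part of the patch): how the reworked toSqlType is driven end to end. The schema literal and the depth value below are hypothetical; the four-argument call shape follows the signature added above, and the depth semantics follow the tests later in this diff.

  import org.apache.avro.Schema
  import org.apache.spark.sql.avro.SchemaConverters

  // A self-referential record: each node optionally points to a next node.
  val recursiveAvro = new Schema.Parser().parse(
    """{"type": "record", "name": "Node", "fields": [
      |  {"name": "value", "type": "long"},
      |  {"name": "next", "type": ["null", "Node"]}
      |]}""".stripMargin)

  // recursiveFieldMaxDepth <= 0 keeps the old behaviour and throws
  // IncompatibleSchemaException for this schema; a depth n >= 1 unrolls the
  // recursion n times and drops the recursive field beyond that (logging at INFO),
  // so depth 2 keeps `value` plus one unrolled level of `next`.
  val sqlType = SchemaConverters.toSqlType(recursiveAvro, false, "", 2).dataType
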
diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala
index 1083c99160724..a13faf3b51560 100644
--- a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala
@@ -105,7 +105,8 @@ case class AvroPartitionReaderFactory(
datetimeRebaseMode,
avroFilters,
options.useStableIdForUnionType,
- options.stableIdPrefixForUnionType)
+ options.stableIdPrefixForUnionType,
+ options.recursiveFieldMaxDepth)
override val stopPosition = partitionedFile.start + partitionedFile.length
override def next(): Boolean = hasNextRow
diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala
index fe61fe3db8786..8ec711b2757f5 100644
--- a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala
@@ -37,7 +37,7 @@ case class AvroTable(
fallbackFileFormat: Class[_ <: FileFormat])
extends FileTable(sparkSession, options, paths, userSpecifiedSchema) {
override def newScanBuilder(options: CaseInsensitiveStringMap): AvroScanBuilder =
- new AvroScanBuilder(sparkSession, fileIndex, schema, dataSchema, options)
+ AvroScanBuilder(sparkSession, fileIndex, schema, dataSchema, mergedOptions(options))
override def inferSchema(files: Seq[FileStatus]): Option[StructType] =
AvroUtils.inferSchema(sparkSession, options.asScala.toMap, files)
diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala
index 388347537a4d6..311eda3a1b6ae 100644
--- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala
@@ -291,7 +291,8 @@ class AvroCatalystDataConversionSuite extends SparkFunSuite
RebaseSpec(LegacyBehaviorPolicy.CORRECTED),
filters,
false,
- "")
+ "",
+ -1)
val deserialized = deserializer.deserialize(data)
expected match {
case None => assert(deserialized == None)
diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCodecSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCodecSuite.scala
index 256b608feaa1f..0db9d284c4512 100644
--- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCodecSuite.scala
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCodecSuite.scala
@@ -54,7 +54,7 @@ class AvroCodecSuite extends FileSourceCodecSuite {
s"""CREATE TABLE avro_t
|USING $format OPTIONS('compression'='unsupported')
|AS SELECT 1 as id""".stripMargin)),
- errorClass = "CODEC_SHORT_NAME_NOT_FOUND",
+ condition = "CODEC_SHORT_NAME_NOT_FOUND",
sqlState = Some("42704"),
parameters = Map("codecName" -> "unsupported")
)
diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala
index 432c3fa9be3ac..a7f7abadcf485 100644
--- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.execution.LocalTableScanExec
import org.apache.spark.sql.functions.{col, lit, struct}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.{BinaryType, StructType}
+import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType}
class AvroFunctionsSuite extends QueryTest with SharedSparkSession {
import testImplicits._
@@ -329,7 +329,7 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession {
s"""
|select to_avro(s, 42) as result from t
|""".stripMargin)),
- errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
+ condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
parameters = Map("sqlExpr" -> "\"to_avro(s, 42)\"",
"msg" -> ("The second argument of the TO_AVRO SQL function must be a constant string " +
"containing the JSON representation of the schema to use for converting the value to " +
@@ -344,7 +344,7 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession {
s"""
|select from_avro(s, 42, '') as result from t
|""".stripMargin)),
- errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
+ condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
parameters = Map("sqlExpr" -> "\"from_avro(s, 42, )\"",
"msg" -> ("The second argument of the FROM_AVRO SQL function must be a constant string " +
"containing the JSON representation of the schema to use for converting the value " +
@@ -359,7 +359,7 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession {
s"""
|select from_avro(s, '$jsonFormatSchema', 42) as result from t
|""".stripMargin)),
- errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
+ condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
parameters = Map(
"sqlExpr" ->
s"\"from_avro(s, $jsonFormatSchema, 42)\"".stripMargin,
@@ -374,6 +374,37 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession {
}
}
+
+ test("roundtrip in to_avro and from_avro - recursive schema") {
+ val catalystSchema =
+ StructType(Seq(
+ StructField("Id", IntegerType),
+ StructField("Name", StructType(Seq(
+ StructField("Id", IntegerType),
+ StructField("Name", StructType(Seq(
+ StructField("Id", IntegerType)))))))))
+
+ val avroSchema = s"""
+ |{
+ | "type" : "record",
+ | "name" : "test_schema",
+ | "fields" : [
+ | {"name": "Id", "type": "int"},
+ | {"name": "Name", "type": ["null", "test_schema"]}
+ | ]
+ |}
+ """.stripMargin
+
+ val df = spark.createDataFrame(
+ spark.sparkContext.parallelize(Seq(Row(2, Row(3, Row(4))), Row(1, null))),
+ catalystSchema).select(struct("Id", "Name").as("struct"))
+
+ val avroStructDF = df.select(functions.to_avro($"struct", avroSchema).as("avro"))
+ checkAnswer(avroStructDF.select(
+ functions.from_avro($"avro", avroSchema, Map(
+ "recursiveFieldMaxDepth" -> "3").asJava)), df)
+ }
+
private def serialize(record: GenericRecord, avroSchema: String): Array[Byte] = {
val schema = new Schema.Parser().parse(avroSchema)
val datumWriter = new GenericDatumWriter[GenericRecord](schema)
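
Usage sketch (not part of the patch) mirroring the round-trip test above: passing recursiveFieldMaxDepth through the from_avro options map. The SparkSession `spark` and the DataFrame `df` with a struct column named "struct" are assumed to exist, built as in the test.

  import scala.jdk.CollectionConverters._
  import org.apache.spark.sql.avro.functions.{from_avro, to_avro}
  import spark.implicits._

  val avroSchema =
    """{"type": "record", "name": "test_schema", "fields": [
      |  {"name": "Id", "type": "int"},
      |  {"name": "Name", "type": ["null", "test_schema"]}
      |]}""".stripMargin

  // Encode the struct column to Avro bytes, then decode it back while allowing up to
  // three levels of the self-referencing "Name" field.
  val encoded = df.select(to_avro($"struct", avroSchema).as("avro"))
  val decoded = encoded.select(
    from_avro($"avro", avroSchema, Map("recursiveFieldMaxDepth" -> "3").asJava))
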
diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala
index 429f3c0deca6a..751ac275e048a 100644
--- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala
@@ -439,7 +439,7 @@ abstract class AvroLogicalTypeSuite extends QueryTest with SharedSparkSession {
assert(ex.getErrorClass.startsWith("FAILED_READ_FILE"))
checkError(
exception = ex.getCause.asInstanceOf[SparkArithmeticException],
- errorClass = "NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION",
+ condition = "NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION",
parameters = Map(
"value" -> "0",
"precision" -> "4",
diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala
index 9b3bb929a700d..c1ab96a63eb26 100644
--- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala
@@ -77,7 +77,8 @@ class AvroRowReaderSuite
RebaseSpec(CORRECTED),
new NoopFilters,
false,
- "")
+ "",
+ -1)
override val stopPosition = fileSize
override def hasNext: Boolean = hasNextRow
diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala
index cbcbc2e7e76a6..3643a95abe19c 100644
--- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala
@@ -228,7 +228,8 @@ object AvroSerdeSuite {
RebaseSpec(CORRECTED),
new NoopFilters,
false,
- "")
+ "",
+ -1)
}
/**
diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
index b20ee4b3cc231..be887bd5237b0 100644
--- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
@@ -894,7 +894,7 @@ abstract class AvroSuite
assert(ex.getErrorClass.startsWith("FAILED_READ_FILE"))
checkError(
exception = ex.getCause.asInstanceOf[AnalysisException],
- errorClass = "AVRO_INCOMPATIBLE_READ_TYPE",
+ condition = "AVRO_INCOMPATIBLE_READ_TYPE",
parameters = Map("avroPath" -> "field 'a'",
"sqlPath" -> "field 'a'",
"avroType" -> "decimal\\(12,10\\)",
@@ -972,7 +972,7 @@ abstract class AvroSuite
assert(ex.getErrorClass.startsWith("FAILED_READ_FILE"))
checkError(
exception = ex.getCause.asInstanceOf[AnalysisException],
- errorClass = "AVRO_INCOMPATIBLE_READ_TYPE",
+ condition = "AVRO_INCOMPATIBLE_READ_TYPE",
parameters = Map("avroPath" -> "field 'a'",
"sqlPath" -> "field 'a'",
"avroType" -> "interval day to second",
@@ -1009,7 +1009,7 @@ abstract class AvroSuite
assert(ex.getErrorClass.startsWith("FAILED_READ_FILE"))
checkError(
exception = ex.getCause.asInstanceOf[AnalysisException],
- errorClass = "AVRO_INCOMPATIBLE_READ_TYPE",
+ condition = "AVRO_INCOMPATIBLE_READ_TYPE",
parameters = Map("avroPath" -> "field 'a'",
"sqlPath" -> "field 'a'",
"avroType" -> "interval year to month",
@@ -1673,7 +1673,7 @@ abstract class AvroSuite
exception = intercept[AnalysisException] {
sql("select interval 1 days").write.format("avro").mode("overwrite").save(tempDir)
},
- errorClass = "_LEGACY_ERROR_TEMP_1136",
+ condition = "_LEGACY_ERROR_TEMP_1136",
parameters = Map.empty
)
checkError(
@@ -1681,7 +1681,7 @@ abstract class AvroSuite
spark.udf.register("testType", () => new IntervalData())
sql("select testType()").write.format("avro").mode("overwrite").save(tempDir)
},
- errorClass = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE",
+ condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE",
parameters = Map(
"columnName" -> "`testType()`",
"columnType" -> "UDT(\"INTERVAL\")",
@@ -2220,7 +2220,9 @@ abstract class AvroSuite
}
}
- private def checkSchemaWithRecursiveLoop(avroSchema: String): Unit = {
+ private def checkSchemaWithRecursiveLoop(avroSchema: String, recursiveFieldMaxDepth: Int):
+ Unit = {
val message = intercept[IncompatibleSchemaException] {
- SchemaConverters.toSqlType(new Schema.Parser().parse(avroSchema), false, "")
+ SchemaConverters.toSqlType(
+ new Schema.Parser().parse(avroSchema), false, "", recursiveFieldMaxDepth)
}.getMessage
@@ -2229,7 +2230,79 @@ abstract class AvroSuite
}
test("Detect recursive loop") {
- checkSchemaWithRecursiveLoop("""
+ for (recursiveFieldMaxDepth <- Seq(-1, 0)) {
+ checkSchemaWithRecursiveLoop(
+ """
+ |{
+ | "type": "record",
+ | "name": "LongList",
+ | "fields" : [
+ | {"name": "value", "type": "long"}, // each element has a long
+ | {"name": "next", "type": ["null", "LongList"]} // optional next element
+ | ]
+ |}
+ """.stripMargin, recursiveFieldMaxDepth)
+
+ checkSchemaWithRecursiveLoop(
+ """
+ |{
+ | "type": "record",
+ | "name": "LongList",
+ | "fields": [
+ | {
+ | "name": "value",
+ | "type": {
+ | "type": "record",
+ | "name": "foo",
+ | "fields": [
+ | {
+ | "name": "parent",
+ | "type": "LongList"
+ | }
+ | ]
+ | }
+ | }
+ | ]
+ |}
+ """.stripMargin, recursiveFieldMaxDepth)
+
+ checkSchemaWithRecursiveLoop(
+ """
+ |{
+ | "type": "record",
+ | "name": "LongList",
+ | "fields" : [
+ | {"name": "value", "type": "long"},
+ | {"name": "array", "type": {"type": "array", "items": "LongList"}}
+ | ]
+ |}
+ """.stripMargin, recursiveFieldMaxDepth)
+
+ checkSchemaWithRecursiveLoop(
+ """
+ |{
+ | "type": "record",
+ | "name": "LongList",
+ | "fields" : [
+ | {"name": "value", "type": "long"},
+ | {"name": "map", "type": {"type": "map", "values": "LongList"}}
+ | ]
+ |}
+ """.stripMargin, recursiveFieldMaxDepth)
+ }
+ }
+
+ private def checkSparkSchemaEquals(
+ avroSchema: String, expectedSchema: StructType, recursiveFieldMaxDepth: Int): Unit = {
+ val sparkSchema =
+ SchemaConverters.toSqlType(
+ new Schema.Parser().parse(avroSchema), false, "", recursiveFieldMaxDepth).dataType
+
+ assert(sparkSchema === expectedSchema)
+ }
+
+ test("Translate recursive schema - union") {
+ val avroSchema = """
|{
| "type": "record",
| "name": "LongList",
@@ -2238,9 +2311,57 @@ abstract class AvroSuite
| {"name": "next", "type": ["null", "LongList"]} // optional next element
| ]
|}
- """.stripMargin)
+ """.stripMargin
+ val nonRecursiveFields = new StructType().add("value", LongType, nullable = false)
+ var expectedSchema = nonRecursiveFields
+ for (i <- 1 to 5) {
+ checkSparkSchemaEquals(avroSchema, expectedSchema, i)
+ expectedSchema = nonRecursiveFields.add("next", expectedSchema)
+ }
+ }
+
+ test("Translate recursive schema - union - 2 non-null fields") {
+ val avroSchema = """
+ |{
+ | "type": "record",
+ | "name": "TreeNode",
+ | "fields": [
+ | {
+ | "name": "name",
+ | "type": "string"
+ | },
+ | {
+ | "name": "value",
+ | "type": [
+ | "long"
+ | ]
+ | },
+ | {
+ | "name": "children",
+ | "type": [
+ | "null",
+ | {
+ | "type": "array",
+ | "items": "TreeNode"
+ | }
+ | ],
+ | "default": null
+ | }
+ | ]
+ |}
+ """.stripMargin
+ val nonRecursiveFields = new StructType().add("name", StringType, nullable = false)
+ .add("value", LongType, nullable = false)
+ var expectedSchema = nonRecursiveFields
+ for (i <- 1 to 5) {
+ checkSparkSchemaEquals(avroSchema, expectedSchema, i)
+ expectedSchema = nonRecursiveFields.add("children",
+ new ArrayType(expectedSchema, false), nullable = true)
+ }
+ }
- checkSchemaWithRecursiveLoop("""
+ test("Translate recursive schema - record") {
+ val avroSchema = """
|{
| "type": "record",
| "name": "LongList",
@@ -2260,9 +2381,18 @@ abstract class AvroSuite
| }
| ]
|}
- """.stripMargin)
+ """.stripMargin
+ val nonRecursiveFields = new StructType().add("value", StructType(Seq()), nullable = false)
+ var expectedSchema = nonRecursiveFields
+ for (i <- 1 to 5) {
+ checkSparkSchemaEquals(avroSchema, expectedSchema, i)
+ expectedSchema = new StructType().add("value",
+ new StructType().add("parent", expectedSchema, nullable = false), nullable = false)
+ }
+ }
- checkSchemaWithRecursiveLoop("""
+ test("Translate recursive schema - array") {
+ val avroSchema = """
|{
| "type": "record",
| "name": "LongList",
@@ -2271,9 +2401,18 @@ abstract class AvroSuite
| {"name": "array", "type": {"type": "array", "items": "LongList"}}
| ]
|}
- """.stripMargin)
+ """.stripMargin
+ val nonRecursiveFields = new StructType().add("value", LongType, nullable = false)
+ var expectedSchema = nonRecursiveFields
+ for (i <- 1 to 5) {
+ checkSparkSchemaEquals(avroSchema, expectedSchema, i)
+ expectedSchema =
+ nonRecursiveFields.add("array", new ArrayType(expectedSchema, false), nullable = false)
+ }
+ }
- checkSchemaWithRecursiveLoop("""
+ test("Translate recursive schema - map") {
+ val avroSchema = """
|{
| "type": "record",
| "name": "LongList",
@@ -2282,7 +2421,70 @@ abstract class AvroSuite
| {"name": "map", "type": {"type": "map", "values": "LongList"}}
| ]
|}
- """.stripMargin)
+ """.stripMargin
+ val nonRecursiveFields = new StructType().add("value", LongType, nullable = false)
+ var expectedSchema = nonRecursiveFields
+ for (i <- 1 to 5) {
+ checkSparkSchemaEquals(avroSchema, expectedSchema, i)
+ expectedSchema =
+ nonRecursiveFields.add("map",
+ new MapType(StringType, expectedSchema, false), nullable = false)
+ }
+ }
+
+ test("recursive schema integration test") {
+ val catalystSchema =
+ StructType(Seq(
+ StructField("Id", IntegerType),
+ StructField("Name", StructType(Seq(
+ StructField("Id", IntegerType),
+ StructField("Name", StructType(Seq(
+ StructField("Id", IntegerType),
+ StructField("Name", NullType)))))))))
+
+ val avroSchema = s"""
+ |{
+ | "type" : "record",
+ | "name" : "test_schema",
+ | "fields" : [
+ | {"name": "Id", "type": "int"},
+ | {"name": "Name", "type": ["null", "test_schema"]}
+ | ]
+ |}
+ """.stripMargin
+
+ val df = spark.createDataFrame(
+ spark.sparkContext.parallelize(Seq(Row(2, Row(3, Row(4, null))), Row(1, null))),
+ catalystSchema)
+
+ withTempPath { tempDir =>
+ df.write.format("avro").save(tempDir.getPath)
+
+ val exc = intercept[AnalysisException] {
+ spark.read
+ .format("avro")
+ .option("avroSchema", avroSchema)
+ .option("recursiveFieldMaxDepth", 16)
+ .load(tempDir.getPath)
+ }
+ assert(exc.getMessage.contains("Should not be greater than 15."))
+
+ checkAnswer(
+ spark.read
+ .format("avro")
+ .option("avroSchema", avroSchema)
+ .option("recursiveFieldMaxDepth", 10)
+ .load(tempDir.getPath),
+ df)
+
+ checkAnswer(
+ spark.read
+ .format("avro")
+ .option("avroSchema", avroSchema)
+ .option("recursiveFieldMaxDepth", 1)
+ .load(tempDir.getPath),
+ df.select("Id"))
+ }
}
test("log a warning of ignoreExtension deprecation") {
@@ -2726,7 +2928,7 @@ abstract class AvroSuite
|LOCATION '${dir}'
|AS SELECT ID, IF(ID=1,1,0) FROM v""".stripMargin)
},
- errorClass = "INVALID_COLUMN_NAME_AS_PATH",
+ condition = "INVALID_COLUMN_NAME_AS_PATH",
parameters = Map(
"datasource" -> "AvroFileFormat", "columnName" -> "`(IF((ID = 1), 1, 0))`")
)
@@ -2777,7 +2979,7 @@ abstract class AvroSuite
}
test("SPARK-40667: validate Avro Options") {
- assert(AvroOptions.getAllOptions.size == 11)
+ assert(AvroOptions.getAllOptions.size == 12)
// Please add validation on any new Avro options here
assert(AvroOptions.isValidOption("ignoreExtension"))
assert(AvroOptions.isValidOption("mode"))
@@ -2790,6 +2992,7 @@ abstract class AvroSuite
assert(AvroOptions.isValidOption("datetimeRebaseMode"))
assert(AvroOptions.isValidOption("enableStableIdentifiersForUnionType"))
assert(AvroOptions.isValidOption("stableIdentifierPrefixForUnionType"))
+ assert(AvroOptions.isValidOption("recursiveFieldMaxDepth"))
}
test("SPARK-46633: read file with empty blocks") {
@@ -2831,7 +3034,7 @@ class AvroV1Suite extends AvroSuite {
sql("SELECT ID, IF(ID=1,1,0) FROM v").write.mode(SaveMode.Overwrite)
.format("avro").save(dir.getCanonicalPath)
},
- errorClass = "INVALID_COLUMN_NAME_AS_PATH",
+ condition = "INVALID_COLUMN_NAME_AS_PATH",
parameters = Map(
"datasource" -> "AvroFileFormat", "columnName" -> "`(IF((ID = 1), 1, 0))`")
)
@@ -2844,7 +3047,7 @@ class AvroV1Suite extends AvroSuite {
.write.mode(SaveMode.Overwrite)
.format("avro").save(dir.getCanonicalPath)
},
- errorClass = "INVALID_COLUMN_NAME_AS_PATH",
+ condition = "INVALID_COLUMN_NAME_AS_PATH",
parameters = Map(
"datasource" -> "AvroFileFormat", "columnName" -> "`(IF((ID = 1), 1, 0))`")
)
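
Reader-level sketch (not part of the patch) of the new data source option, as exercised by the "recursive schema integration test" above. `spark`, the self-referencing `avroSchema` string, and `path` are assumed; per the test, values above 15 are rejected and depth 1 prunes the recursive branch down to the top-level fields.

  val fullDepth = spark.read
    .format("avro")
    .option("avroSchema", avroSchema)
    .option("recursiveFieldMaxDepth", 10)  // anything above 15 fails validation
    .load(path)

  val topLevelOnly = spark.read
    .format("avro")
    .option("avroSchema", avroSchema)
    .option("recursiveFieldMaxDepth", 1)   // recursive `Name` branch is dropped; only `Id` remains
    .load(path)
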
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index c06cbbc0cdb42..3777f82594aae 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -22,6 +22,7 @@ import scala.jdk.CollectionConverters._
import org.apache.spark.connect.proto.{NAReplace, Relation}
import org.apache.spark.connect.proto.Expression.{Literal => GLiteral}
import org.apache.spark.connect.proto.NAReplace.Replacement
+import org.apache.spark.sql.connect.ConnectConversions._
/**
* Functionality for working with missing data in `DataFrame`s.
@@ -29,7 +30,7 @@ import org.apache.spark.connect.proto.NAReplace.Replacement
* @since 3.4.0
*/
final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root: Relation)
- extends api.DataFrameNaFunctions[Dataset] {
+ extends api.DataFrameNaFunctions {
import sparkSession.RichColumn
override protected def drop(minNonNulls: Option[Int]): Dataset[Row] =
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 1ad98dc91b216..60bacd4e18ede 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -23,11 +23,8 @@ import scala.jdk.CollectionConverters._
import org.apache.spark.annotation.Stable
import org.apache.spark.connect.proto.Parse.ParseFormat
-import org.apache.spark.internal.Logging
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder
-import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, SparkCharVarcharUtils}
+import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.connect.common.DataTypeProtoConverter
-import org.apache.spark.sql.errors.DataTypeErrors
import org.apache.spark.sql.types.StructType
/**
@@ -37,144 +34,44 @@ import org.apache.spark.sql.types.StructType
* @since 3.4.0
*/
@Stable
-class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging {
-
- /**
- * Specifies the input data source format.
- *
- * @since 3.4.0
- */
- def format(source: String): DataFrameReader = {
- this.source = source
- this
- }
+class DataFrameReader private[sql] (sparkSession: SparkSession) extends api.DataFrameReader {
+ type DS[U] = Dataset[U]
- /**
- * Specifies the input schema. Some data sources (e.g. JSON) can infer the input schema
- * automatically from data. By specifying the schema here, the underlying data source can skip
- * the schema inference step, and thus speed up data loading.
- *
- * @since 3.4.0
- */
- def schema(schema: StructType): DataFrameReader = {
- if (schema != null) {
- val replaced = SparkCharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType]
- this.userSpecifiedSchema = Option(replaced)
- }
- this
- }
+ /** @inheritdoc */
+ override def format(source: String): this.type = super.format(source)
- /**
- * Specifies the schema by using the input DDL-formatted string. Some data sources (e.g. JSON)
- * can infer the input schema automatically from data. By specifying the schema here, the
- * underlying data source can skip the schema inference step, and thus speed up data loading.
- *
- * {{{
- * spark.read.schema("a INT, b STRING, c DOUBLE").csv("test.csv")
- * }}}
- *
- * @since 3.4.0
- */
- def schema(schemaString: String): DataFrameReader = {
- schema(StructType.fromDDL(schemaString))
- }
+ /** @inheritdoc */
+ override def schema(schema: StructType): this.type = super.schema(schema)
- /**
- * Adds an input option for the underlying data source.
- *
- * All options are maintained in a case-insensitive way in terms of key names. If a new option
- * has the same key case-insensitively, it will override the existing option.
- *
- * @since 3.4.0
- */
- def option(key: String, value: String): DataFrameReader = {
- this.extraOptions = this.extraOptions + (key -> value)
- this
- }
+ /** @inheritdoc */
+ override def schema(schemaString: String): this.type = super.schema(schemaString)
- /**
- * Adds an input option for the underlying data source.
- *
- * All options are maintained in a case-insensitive way in terms of key names. If a new option
- * has the same key case-insensitively, it will override the existing option.
- *
- * @since 3.4.0
- */
- def option(key: String, value: Boolean): DataFrameReader = option(key, value.toString)
-
- /**
- * Adds an input option for the underlying data source.
- *
- * All options are maintained in a case-insensitive way in terms of key names. If a new option
- * has the same key case-insensitively, it will override the existing option.
- *
- * @since 3.4.0
- */
- def option(key: String, value: Long): DataFrameReader = option(key, value.toString)
-
- /**
- * Adds an input option for the underlying data source.
- *
- * All options are maintained in a case-insensitive way in terms of key names. If a new option
- * has the same key case-insensitively, it will override the existing option.
- *
- * @since 3.4.0
- */
- def option(key: String, value: Double): DataFrameReader = option(key, value.toString)
-
- /**
- * (Scala-specific) Adds input options for the underlying data source.
- *
- * All options are maintained in a case-insensitive way in terms of key names. If a new option
- * has the same key case-insensitively, it will override the existing option.
- *
- * @since 3.4.0
- */
- def options(options: scala.collection.Map[String, String]): DataFrameReader = {
- this.extraOptions ++= options
- this
- }
+ /** @inheritdoc */
+ override def option(key: String, value: String): this.type = super.option(key, value)
- /**
- * Adds input options for the underlying data source.
- *
- * All options are maintained in a case-insensitive way in terms of key names. If a new option
- * has the same key case-insensitively, it will override the existing option.
- *
- * @since 3.4.0
- */
- def options(options: java.util.Map[String, String]): DataFrameReader = {
- this.options(options.asScala)
- this
- }
+ /** @inheritdoc */
+ override def option(key: String, value: Boolean): this.type = super.option(key, value)
- /**
- * Loads input in as a `DataFrame`, for data sources that don't require a path (e.g. external
- * key-value stores).
- *
- * @since 3.4.0
- */
- def load(): DataFrame = {
- load(Seq.empty: _*) // force invocation of `load(...varargs...)`
- }
+ /** @inheritdoc */
+ override def option(key: String, value: Long): this.type = super.option(key, value)
- /**
- * Loads input in as a `DataFrame`, for data sources that require a path (e.g. data backed by a
- * local or distributed file system).
- *
- * @since 3.4.0
- */
- def load(path: String): DataFrame = {
- // force invocation of `load(...varargs...)`
- load(Seq(path): _*)
- }
+ /** @inheritdoc */
+ override def option(key: String, value: Double): this.type = super.option(key, value)
+
+ /** @inheritdoc */
+ override def options(options: scala.collection.Map[String, String]): this.type =
+ super.options(options)
- /**
- * Loads input in as a `DataFrame`, for data sources that support multiple paths. Only works if
- * the source is a HadoopFsRelationProvider.
- *
- * @since 3.4.0
- */
+ /** @inheritdoc */
+ override def options(options: java.util.Map[String, String]): this.type = super.options(options)
+
+ /** @inheritdoc */
+ override def load(): DataFrame = load(Nil: _*)
+
+ /** @inheritdoc */
+ def load(path: String): DataFrame = load(Seq(path): _*)
+
+ /** @inheritdoc */
@scala.annotation.varargs
def load(paths: String*): DataFrame = {
sparkSession.newDataFrame { builder =>
@@ -190,93 +87,29 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging
}
}
- /**
- * Construct a `DataFrame` representing the database table accessible via JDBC URL url named
- * table and connection properties.
- *
- * You can find the JDBC-specific option and parameter documentation for reading tables via JDBC
- * in
- * Data Source Option in the version you use.
- *
- * @since 3.4.0
- */
- def jdbc(url: String, table: String, properties: Properties): DataFrame = {
- // properties should override settings in extraOptions.
- this.extraOptions ++= properties.asScala
- // explicit url and dbtable should override all
- this.extraOptions ++= Seq("url" -> url, "dbtable" -> table)
- format("jdbc").load()
- }
+ /** @inheritdoc */
+ override def jdbc(url: String, table: String, properties: Properties): DataFrame =
+ super.jdbc(url, table, properties)
- // scalastyle:off line.size.limit
- /**
- * Construct a `DataFrame` representing the database table accessible via JDBC URL url named
- * table. Partitions of the table will be retrieved in parallel based on the parameters passed
- * to this function.
- *
- * Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash
- * your external database systems.
- *
- * You can find the JDBC-specific option and parameter documentation for reading tables via JDBC
- * in
- * Data Source Option in the version you use.
- *
- * @param table
- * Name of the table in the external database.
- * @param columnName
- * Alias of `partitionColumn` option. Refer to `partitionColumn` in
- * Data Source Option in the version you use.
- * @param connectionProperties
- * JDBC database connection arguments, a list of arbitrary string tag/value. Normally at least
- * a "user" and "password" property should be included. "fetchsize" can be used to control the
- * number of rows per fetch and "queryTimeout" can be used to wait for a Statement object to
- * execute to the given number of seconds.
- * @since 3.4.0
- */
- // scalastyle:on line.size.limit
- def jdbc(
+ /** @inheritdoc */
+ override def jdbc(
url: String,
table: String,
columnName: String,
lowerBound: Long,
upperBound: Long,
numPartitions: Int,
- connectionProperties: Properties): DataFrame = {
- // columnName, lowerBound, upperBound and numPartitions override settings in extraOptions.
- this.extraOptions ++= Map(
- "partitionColumn" -> columnName,
- "lowerBound" -> lowerBound.toString,
- "upperBound" -> upperBound.toString,
- "numPartitions" -> numPartitions.toString)
- jdbc(url, table, connectionProperties)
- }
-
- /**
- * Construct a `DataFrame` representing the database table accessible via JDBC URL url named
- * table using connection properties. The `predicates` parameter gives a list expressions
- * suitable for inclusion in WHERE clauses; each one defines one partition of the `DataFrame`.
- *
- * Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash
- * your external database systems.
- *
- * You can find the JDBC-specific option and parameter documentation for reading tables via JDBC
- * in
- * Data Source Option in the version you use.
- *
- * @param table
- * Name of the table in the external database.
- * @param predicates
- * Condition in the where clause for each partition.
- * @param connectionProperties
- * JDBC database connection arguments, a list of arbitrary string tag/value. Normally at least
- * a "user" and "password" property should be included. "fetchsize" can be used to control the
- * number of rows per fetch.
- * @since 3.4.0
- */
+ connectionProperties: Properties): DataFrame =
+ super.jdbc(
+ url,
+ table,
+ columnName,
+ lowerBound,
+ upperBound,
+ numPartitions,
+ connectionProperties)
+
+ /** @inheritdoc */
def jdbc(
url: String,
table: String,
@@ -296,207 +129,56 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging
}
}
- /**
- * Loads a JSON file and returns the results as a `DataFrame`.
- *
- * See the documentation on the overloaded `json()` method with varargs for more details.
- *
- * @since 3.4.0
- */
- def json(path: String): DataFrame = {
- // This method ensures that calls that explicit need single argument works, see SPARK-16009
- json(Seq(path): _*)
- }
+ /** @inheritdoc */
+ override def json(path: String): DataFrame = super.json(path)
- /**
- * Loads JSON files and returns the results as a `DataFrame`.
- *
- * JSON Lines (newline-delimited JSON) is supported by
- * default. For JSON (one record per file), set the `multiLine` option to true.
- *
- * This function goes through the input once to determine the input schema. If you know the
- * schema in advance, use the version that specifies the schema to avoid the extra scan.
- *
- * You can find the JSON-specific options for reading JSON files in
- * Data Source Option in the version you use.
- *
- * @since 3.4.0
- */
+ /** @inheritdoc */
@scala.annotation.varargs
- def json(paths: String*): DataFrame = {
- format("json").load(paths: _*)
- }
+ override def json(paths: String*): DataFrame = super.json(paths: _*)
- /**
- * Loads a `Dataset[String]` storing JSON objects (JSON Lines
- * text format or newline-delimited JSON) and returns the result as a `DataFrame`.
- *
- * Unless the schema is specified using `schema` function, this function goes through the input
- * once to determine the input schema.
- *
- * @param jsonDataset
- * input Dataset with one JSON object per record
- * @since 3.4.0
- */
+ /** @inheritdoc */
def json(jsonDataset: Dataset[String]): DataFrame =
parse(jsonDataset, ParseFormat.PARSE_FORMAT_JSON)
- /**
- * Loads a CSV file and returns the result as a `DataFrame`. See the documentation on the other
- * overloaded `csv()` method for more details.
- *
- * @since 3.4.0
- */
- def csv(path: String): DataFrame = {
- // This method ensures that calls that explicit need single argument works, see SPARK-16009
- csv(Seq(path): _*)
- }
+ /** @inheritdoc */
+ override def csv(path: String): DataFrame = super.csv(path)
- /**
- * Loads CSV files and returns the result as a `DataFrame`.
- *
- * This function will go through the input once to determine the input schema if `inferSchema`
- * is enabled. To avoid going through the entire data once, disable `inferSchema` option or
- * specify the schema explicitly using `schema`.
- *
- * You can find the CSV-specific options for reading CSV files in
- * Data Source Option in the version you use.
- *
- * @since 3.4.0
- */
+ /** @inheritdoc */
@scala.annotation.varargs
- def csv(paths: String*): DataFrame = format("csv").load(paths: _*)
-
- /**
- * Loads an `Dataset[String]` storing CSV rows and returns the result as a `DataFrame`.
- *
- * If the schema is not specified using `schema` function and `inferSchema` option is enabled,
- * this function goes through the input once to determine the input schema.
- *
- * If the schema is not specified using `schema` function and `inferSchema` option is disabled,
- * it determines the columns as string types and it reads only the first line to determine the
- * names and the number of fields.
- *
- * If the enforceSchema is set to `false`, only the CSV header in the first line is checked to
- * conform specified or inferred schema.
- *
- * @note
- * if `header` option is set to `true` when calling this API, all lines same with the header
- * will be removed if exists.
- * @param csvDataset
- * input Dataset with one CSV row per record
- * @since 3.4.0
- */
+ override def csv(paths: String*): DataFrame = super.csv(paths: _*)
+
+ /** @inheritdoc */
def csv(csvDataset: Dataset[String]): DataFrame =
parse(csvDataset, ParseFormat.PARSE_FORMAT_CSV)
- /**
- * Loads a XML file and returns the result as a `DataFrame`. See the documentation on the other
- * overloaded `xml()` method for more details.
- *
- * @since 4.0.0
- */
- def xml(path: String): DataFrame = {
- // This method ensures that calls that explicit need single argument works, see SPARK-16009
- xml(Seq(path): _*)
- }
+ /** @inheritdoc */
+ override def xml(path: String): DataFrame = super.xml(path)
- /**
- * Loads XML files and returns the result as a `DataFrame`.
- *
- * This function will go through the input once to determine the input schema if `inferSchema`
- * is enabled. To avoid going through the entire data once, disable `inferSchema` option or
- * specify the schema explicitly using `schema`.
- *
- * You can find the XML-specific options for reading XML files in
- * Data Source Option in the version you use.
- *
- * @since 4.0.0
- */
+ /** @inheritdoc */
@scala.annotation.varargs
- def xml(paths: String*): DataFrame = format("xml").load(paths: _*)
-
- /**
- * Loads an `Dataset[String]` storing XML object and returns the result as a `DataFrame`.
- *
- * If the schema is not specified using `schema` function and `inferSchema` option is enabled,
- * this function goes through the input once to determine the input schema.
- *
- * @param xmlDataset
- * input Dataset with one XML object per record
- * @since 4.0.0
- */
+ override def xml(paths: String*): DataFrame = super.xml(paths: _*)
+
+ /** @inheritdoc */
def xml(xmlDataset: Dataset[String]): DataFrame =
parse(xmlDataset, ParseFormat.PARSE_FORMAT_UNSPECIFIED)
- /**
- * Loads a Parquet file, returning the result as a `DataFrame`. See the documentation on the
- * other overloaded `parquet()` method for more details.
- *
- * @since 3.4.0
- */
- def parquet(path: String): DataFrame = {
- // This method ensures that calls that explicit need single argument works, see SPARK-16009
- parquet(Seq(path): _*)
- }
+ /** @inheritdoc */
+ override def parquet(path: String): DataFrame = super.parquet(path)
- /**
- * Loads a Parquet file, returning the result as a `DataFrame`.
- *
- * Parquet-specific option(s) for reading Parquet files can be found in Data
- * Source Option in the version you use.
- *
- * @since 3.4.0
- */
+ /** @inheritdoc */
@scala.annotation.varargs
- def parquet(paths: String*): DataFrame = {
- format("parquet").load(paths: _*)
- }
+ override def parquet(paths: String*): DataFrame = super.parquet(paths: _*)
- /**
- * Loads an ORC file and returns the result as a `DataFrame`.
- *
- * @param path
- * input path
- * @since 3.4.0
- */
- def orc(path: String): DataFrame = {
- // This method ensures that calls that explicit need single argument works, see SPARK-16009
- orc(Seq(path): _*)
- }
+ /** @inheritdoc */
+ override def orc(path: String): DataFrame = super.orc(path)
- /**
- * Loads ORC files and returns the result as a `DataFrame`.
- *
- * ORC-specific option(s) for reading ORC files can be found in Data
- * Source Option in the version you use.
- *
- * @param paths
- * input paths
- * @since 3.4.0
- */
+ /** @inheritdoc */
@scala.annotation.varargs
- def orc(paths: String*): DataFrame = format("orc").load(paths: _*)
-
- /**
- * Returns the specified table/view as a `DataFrame`. If it's a table, it must support batch
- * reading and the returned DataFrame is the batch scan query plan of this table. If it's a
- * view, the returned DataFrame is simply the query plan of the view, which can either be a
- * batch or streaming query plan.
- *
- * @param tableName
- * is either a qualified or unqualified name that designates a table or view. If a database is
- * specified, it identifies the table/view from the database. Otherwise, it first attempts to
- * find a temporary view with the given name and then match the table/view from the current
- * database. Note that, the global temporary view database is also valid here.
- * @since 3.4.0
- */
+ override def orc(paths: String*): DataFrame = super.orc(paths: _*)
+
+ /** @inheritdoc */
def table(tableName: String): DataFrame = {
+ assertNoSpecifiedSchema("table")
sparkSession.newDataFrame { builder =>
builder.getReadBuilder.getNamedTableBuilder
.setUnparsedIdentifier(tableName)
@@ -504,80 +186,19 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging
}
}
- /**
- * Loads text files and returns a `DataFrame` whose schema starts with a string column named
- * "value", and followed by partitioned columns if there are any. See the documentation on the
- * other overloaded `text()` method for more details.
- *
- * @since 3.4.0
- */
- def text(path: String): DataFrame = {
- // This method ensures that calls that explicit need single argument works, see SPARK-16009
- text(Seq(path): _*)
- }
+ /** @inheritdoc */
+ override def text(path: String): DataFrame = super.text(path)
- /**
- * Loads text files and returns a `DataFrame` whose schema starts with a string column named
- * "value", and followed by partitioned columns if there are any. The text files must be encoded
- * as UTF-8.
- *
- * By default, each line in the text files is a new row in the resulting DataFrame. For example:
- * {{{
- * // Scala:
- * spark.read.text("/path/to/spark/README.md")
- *
- * // Java:
- * spark.read().text("/path/to/spark/README.md")
- * }}}
- *
- * You can find the text-specific options for reading text files in
- * Data Source Option in the version you use.
- *
- * @param paths
- * input paths
- * @since 3.4.0
- */
+ /** @inheritdoc */
@scala.annotation.varargs
- def text(paths: String*): DataFrame = format("text").load(paths: _*)
-
- /**
- * Loads text files and returns a [[Dataset]] of String. See the documentation on the other
- * overloaded `textFile()` method for more details.
- * @since 3.4.0
- */
- def textFile(path: String): Dataset[String] = {
- // This method ensures that calls that explicit need single argument works, see SPARK-16009
- textFile(Seq(path): _*)
- }
+ override def text(paths: String*): DataFrame = super.text(paths: _*)
- /**
- * Loads text files and returns a [[Dataset]] of String. The underlying schema of the Dataset
- * contains a single string column named "value". The text files must be encoded as UTF-8.
- *
- * If the directory structure of the text files contains partitioning information, those are
- * ignored in the resulting Dataset. To include partitioning information as columns, use `text`.
- *
- * By default, each line in the text files is a new row in the resulting DataFrame. For example:
- * {{{
- * // Scala:
- * spark.read.textFile("/path/to/spark/README.md")
- *
- * // Java:
- * spark.read().textFile("/path/to/spark/README.md")
- * }}}
- *
- * You can set the text-specific options as specified in `DataFrameReader.text`.
- *
- * @param paths
- * input path
- * @since 3.4.0
- */
+ /** @inheritdoc */
+ override def textFile(path: String): Dataset[String] = super.textFile(path)
+
+ /** @inheritdoc */
@scala.annotation.varargs
- def textFile(paths: String*): Dataset[String] = {
- assertNoSpecifiedSchema("textFile")
- text(paths: _*).select("value").as(StringEncoder)
- }
+ override def textFile(paths: String*): Dataset[String] = super.textFile(paths: _*)
private def assertSourceFormatSpecified(): Unit = {
if (source == null) {
@@ -597,24 +218,4 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging
}
}
}
-
- /**
- * A convenient function for schema validation in APIs.
- */
- private def assertNoSpecifiedSchema(operation: String): Unit = {
- if (userSpecifiedSchema.nonEmpty) {
- throw DataTypeErrors.userSpecifiedSchemaUnsupportedError(operation)
- }
- }
-
- ///////////////////////////////////////////////////////////////////////////////////////
- // Builder pattern config options
- ///////////////////////////////////////////////////////////////////////////////////////
-
- private var source: String = _
-
- private var userSpecifiedSchema: Option[StructType] = None
-
- private var extraOptions = CaseInsensitiveMap[String](Map.empty)
-
}
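
Compatibility sketch (not part of the patch): the reader methods removed here are now inherited from the shared api.DataFrameReader, so typical Connect client call chains are unchanged. The paths below are hypothetical.

  // Assumes a Spark Connect SparkSession `spark`.
  val users = spark.read
    .format("json")
    .schema("a INT, b STRING, c DOUBLE")   // DDL-string overload, kept as an override
    .option("multiLine", true)
    .load("/tmp/users.json")

  // textFile still returns a Dataset[String].
  val notes = spark.read.textFile("/tmp/notes.txt")
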
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index 9f5ada0d7ec35..bb7cfa75a9ab9 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -22,6 +22,7 @@ import java.{lang => jl, util => ju}
import org.apache.spark.connect.proto.{Relation, StatSampleBy}
import org.apache.spark.sql.DataFrameStatFunctions.approxQuantileResultEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, PrimitiveDoubleEncoder}
+import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.functions.lit
/**
@@ -30,7 +31,7 @@ import org.apache.spark.sql.functions.lit
* @since 3.4.0
*/
final class DataFrameStatFunctions private[sql] (protected val df: DataFrame)
- extends api.DataFrameStatFunctions[Dataset] {
+ extends api.DataFrameStatFunctions {
private def root: Relation = df.plan.getRoot
private val sparkSession: SparkSession = df.sparkSession
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index ce21f18501a79..accfff9f2b073 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -32,12 +32,13 @@ import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
import org.apache.spark.sql.catalyst.expressions.OrderUtils
+import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.connect.client.SparkResult
-import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, StorageLevelProtoConverter, UdfUtils}
+import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, StorageLevelProtoConverter}
import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
import org.apache.spark.sql.expressions.SparkUserDefinedFunction
import org.apache.spark.sql.functions.{struct, to_json}
-import org.apache.spark.sql.internal.{ColumnNodeToProtoConverter, DataFrameWriterImpl, UnresolvedAttribute, UnresolvedRegex}
+import org.apache.spark.sql.internal.{ColumnNodeToProtoConverter, DataFrameWriterImpl, DataFrameWriterV2Impl, MergeIntoWriterImpl, ToScalaUDF, UDFAdaptors, UnresolvedAttribute, UnresolvedRegex}
import org.apache.spark.sql.streaming.DataStreamWriter
import org.apache.spark.sql.types.{Metadata, StructType}
import org.apache.spark.storage.StorageLevel
@@ -134,8 +135,8 @@ class Dataset[T] private[sql] (
val sparkSession: SparkSession,
@DeveloperApi val plan: proto.Plan,
val encoder: Encoder[T])
- extends api.Dataset[T, Dataset] {
- type RGD = RelationalGroupedDataset
+ extends api.Dataset[T] {
+ type DS[U] = Dataset[U]
import sparkSession.RichColumn
@@ -481,7 +482,7 @@ class Dataset[T] private[sql] (
val unpivot = builder.getUnpivotBuilder
.setInput(plan.getRoot)
.addAllIds(ids.toImmutableArraySeq.map(_.expr).asJava)
- .setValueColumnName(variableColumnName)
+ .setVariableColumnName(variableColumnName)
.setValueColumnName(valueColumnName)
valuesOption.foreach { values =>
unpivot.getValuesBuilder
@@ -489,6 +490,14 @@ class Dataset[T] private[sql] (
}
}
+ private def buildTranspose(indices: Seq[Column]): DataFrame =
+ sparkSession.newDataFrame { builder =>
+ val transpose = builder.getTransposeBuilder.setInput(plan.getRoot)
+ indices.foreach { indexColumn =>
+ transpose.addIndexColumns(indexColumn.expr)
+ }
+ }
+
/** @inheritdoc */
@scala.annotation.varargs
def groupBy(cols: Column*): RelationalGroupedDataset = {
@@ -515,27 +524,11 @@ class Dataset[T] private[sql] (
result(0)
}
- /**
- * (Scala-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given
- * key `func`.
- *
- * @group typedrel
- * @since 3.5.0
- */
+ /** @inheritdoc */
def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = {
KeyValueGroupedDatasetImpl[K, T](this, encoderFor[K], func)
}
- /**
- * (Java-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given
- * key `func`.
- *
- * @group typedrel
- * @since 3.5.0
- */
- def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] =
- groupByKey(UdfUtils.mapFunctionToScalaFunc(func))(encoder)
-
/** @inheritdoc */
@scala.annotation.varargs
def rollup(cols: Column*): RelationalGroupedDataset = {
@@ -582,6 +575,14 @@ class Dataset[T] private[sql] (
buildUnpivot(ids, None, variableColumnName, valueColumnName)
}
+ /** @inheritdoc */
+ def transpose(indexColumn: Column): DataFrame =
+ buildTranspose(Seq(indexColumn))
+
+ /** @inheritdoc */
+ def transpose(): DataFrame =
+ buildTranspose(Seq.empty)
+
/** @inheritdoc */
def limit(n: Int): Dataset[T] = sparkSession.newDataset(agnosticEncoder) { builder =>
builder.getLimitBuilder
@@ -865,17 +866,17 @@ class Dataset[T] private[sql] (
/** @inheritdoc */
def filter(f: FilterFunction[T]): Dataset[T] = {
- filter(UdfUtils.filterFuncToScalaFunc(f))
+ filter(ToScalaUDF(f))
}
/** @inheritdoc */
def map[U: Encoder](f: T => U): Dataset[U] = {
- mapPartitions(UdfUtils.mapFuncToMapPartitionsAdaptor(f))
+ mapPartitions(UDFAdaptors.mapToMapPartitions(f))
}
/** @inheritdoc */
def map[U](f: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
- map(UdfUtils.mapFunctionToScalaFunc(f))(encoder)
+ mapPartitions(UDFAdaptors.mapToMapPartitions(f))(encoder)
}
/** @inheritdoc */
@@ -892,25 +893,11 @@ class Dataset[T] private[sql] (
}
}
- /** @inheritdoc */
- def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
- mapPartitions(UdfUtils.mapPartitionsFuncToScalaFunc(f))(encoder)
- }
-
- /** @inheritdoc */
- override def flatMap[U: Encoder](func: T => IterableOnce[U]): Dataset[U] =
- mapPartitions(UdfUtils.flatMapFuncToMapPartitionsAdaptor(func))
-
- /** @inheritdoc */
- override def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
- flatMap(UdfUtils.flatMapFuncToScalaFunc(f))(encoder)
- }
-
/** @inheritdoc */
@deprecated("use flatMap() or select() with functions.explode() instead", "3.5.0")
def explode[A <: Product: TypeTag](input: Column*)(f: Row => IterableOnce[A]): DataFrame = {
val generator = SparkUserDefinedFunction(
- UdfUtils.iterableOnceToSeq(f),
+ UDFAdaptors.iterableOnceToSeq(f),
UnboundRowEncoder :: Nil,
ScalaReflection.encoderFor[Seq[A]])
select(col("*"), functions.inline(generator(struct(input: _*))))
@@ -921,31 +908,16 @@ class Dataset[T] private[sql] (
def explode[A, B: TypeTag](inputColumn: String, outputColumn: String)(
f: A => IterableOnce[B]): DataFrame = {
val generator = SparkUserDefinedFunction(
- UdfUtils.iterableOnceToSeq(f),
+ UDFAdaptors.iterableOnceToSeq(f),
Nil,
ScalaReflection.encoderFor[Seq[B]])
select(col("*"), functions.explode(generator(col(inputColumn))).as((outputColumn)))
}
- /** @inheritdoc */
- def foreach(f: T => Unit): Unit = {
- foreachPartition(UdfUtils.foreachFuncToForeachPartitionsAdaptor(f))
- }
-
- /** @inheritdoc */
- override def foreach(func: ForeachFunction[T]): Unit =
- foreach(UdfUtils.foreachFuncToScalaFunc(func))
-
/** @inheritdoc */
def foreachPartition(f: Iterator[T] => Unit): Unit = {
// Delegate to mapPartition with empty result.
- mapPartitions(UdfUtils.foreachPartitionFuncToMapPartitionsAdaptor(f))(RowEncoder(Seq.empty))
- .collect()
- }
-
- /** @inheritdoc */
- override def foreachPartition(func: ForeachPartitionFunction[T]): Unit = {
- foreachPartition(UdfUtils.foreachPartitionFuncToScalaFunc(func))
+ mapPartitions(UDFAdaptors.foreachPartitionToMapPartitions(f))(NullEncoder).collect()
}
/** @inheritdoc */
@@ -1047,51 +1019,12 @@ class Dataset[T] private[sql] (
new DataFrameWriterImpl[T](this)
}
- /**
- * Create a write configuration builder for v2 sources.
- *
- * This builder is used to configure and execute write operations. For example, to append to an
- * existing table, run:
- *
- * {{{
- * df.writeTo("catalog.db.table").append()
- * }}}
- *
- * This can also be used to create or replace existing tables:
- *
- * {{{
- * df.writeTo("catalog.db.table").partitionedBy($"col").createOrReplace()
- * }}}
- *
- * @group basic
- * @since 3.4.0
- */
+ /** @inheritdoc */
def writeTo(table: String): DataFrameWriterV2[T] = {
- new DataFrameWriterV2[T](table, this)
+ new DataFrameWriterV2Impl[T](table, this)
}
- /**
- * Merges a set of updates, insertions, and deletions based on a source table into a target
- * table.
- *
- * Scala Examples:
- * {{{
- * spark.table("source")
- * .mergeInto("target", $"source.id" === $"target.id")
- * .whenMatched($"salary" === 100)
- * .delete()
- * .whenNotMatched()
- * .insertAll()
- * .whenNotMatchedBySource($"salary" === 100)
- * .update(Map(
- * "salary" -> lit(200)
- * ))
- * .merge()
- * }}}
- *
- * @group basic
- * @since 4.0.0
- */
+ /** @inheritdoc */
def mergeInto(table: String, condition: Column): MergeIntoWriter[T] = {
if (isStreaming) {
throw new AnalysisException(
@@ -1099,7 +1032,7 @@ class Dataset[T] private[sql] (
messageParameters = Map("methodName" -> toSQLId("mergeInto")))
}
- new MergeIntoWriter[T](table, this, condition)
+ new MergeIntoWriterImpl[T](table, this, condition)
}
/**
@@ -1464,6 +1397,22 @@ class Dataset[T] private[sql] (
override def dropDuplicatesWithinWatermark(col1: String, cols: String*): Dataset[T] =
super.dropDuplicatesWithinWatermark(col1, cols: _*)
+ /** @inheritdoc */
+ override def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] =
+ super.mapPartitions(f, encoder)
+
+ /** @inheritdoc */
+ override def flatMap[U: Encoder](func: T => IterableOnce[U]): Dataset[U] =
+ super.flatMap(func)
+
+ /** @inheritdoc */
+ override def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] =
+ super.flatMap(f, encoder)
+
+ /** @inheritdoc */
+ override def foreachPartition(func: ForeachPartitionFunction[T]): Unit =
+ super.foreachPartition(func)
+
/** @inheritdoc */
@scala.annotation.varargs
override def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] =
@@ -1515,4 +1464,10 @@ class Dataset[T] private[sql] (
/** @inheritdoc */
@scala.annotation.varargs
override def agg(expr: Column, exprs: Column*): DataFrame = super.agg(expr, exprs: _*)
+
+ /** @inheritdoc */
+ override def groupByKey[K](
+ func: MapFunction[T, K],
+ encoder: Encoder[K]): KeyValueGroupedDataset[K, T] =
+ super.groupByKey(func, encoder).asInstanceOf[KeyValueGroupedDataset[K, T]]
}
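A sketch of calling one of the Java-friendly overrides above from Scala (assumes `ds: Dataset[String]` obtained from a connected session):

    import org.apache.spark.api.java.function.MapFunction
    import org.apache.spark.sql.Encoders

    val grouped = ds.groupByKey(
      new MapFunction[String, Integer] {
        override def call(s: String): Integer = Integer.valueOf(s.length)
      },
      Encoders.INT())
    val counts = grouped.count() // Dataset[(Integer, Long)]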
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 04b620bdf8f98..6bf2518901470 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -19,17 +19,19 @@ package org.apache.spark.sql
import java.util.Arrays
+import scala.annotation.unused
import scala.jdk.CollectionConverters._
-import scala.language.existentials
import org.apache.spark.api.java.function._
import org.apache.spark.connect.proto
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.ProductEncoder
+import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.connect.common.UdfUtils
import org.apache.spark.sql.expressions.SparkUserDefinedFunction
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.internal.ColumnNodeToProtoConverter.toExpr
+import org.apache.spark.sql.internal.UDFAdaptors
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeMode}
/**
@@ -39,7 +41,10 @@ import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode
*
* @since 3.5.0
*/
-class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable {
+class KeyValueGroupedDataset[K, V] private[sql] () extends api.KeyValueGroupedDataset[K, V] {
+ type KVDS[KY, VL] = KeyValueGroupedDataset[KY, VL]
+
+ private def unsupported(): Nothing = throw new UnsupportedOperationException()
/**
* Returns a new [[KeyValueGroupedDataset]] where the type of the key has been mapped to the
@@ -48,499 +53,52 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable {
*
* @since 3.5.0
*/
- def keyAs[L: Encoder]: KeyValueGroupedDataset[L, V] = {
- throw new UnsupportedOperationException
- }
+ def keyAs[L: Encoder]: KeyValueGroupedDataset[L, V] = unsupported()
- /**
- * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied to
- * the data. The grouping key is unchanged by this.
- *
- * {{{
- * // Create values grouped by key from a Dataset[(K, V)]
- * ds.groupByKey(_._1).mapValues(_._2) // Scala
- * }}}
- *
- * @since 3.5.0
- */
- def mapValues[W: Encoder](valueFunc: V => W): KeyValueGroupedDataset[K, W] = {
- throw new UnsupportedOperationException
- }
+ /** @inheritdoc */
+ def mapValues[W: Encoder](valueFunc: V => W): KeyValueGroupedDataset[K, W] =
+ unsupported()
- /**
- * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied to
- * the data. The grouping key is unchanged by this.
- *
- * {{{
- * // Create Integer values grouped by String key from a Dataset<Tuple2<String, Integer>>
- * Dataset<Tuple2<String, Integer>> ds = ...;
- * KeyValueGroupedDataset<String, Integer> grouped =
- * ds.groupByKey(t -> t._1, Encoders.STRING()).mapValues(t -> t._2, Encoders.INT());
- * }}}
- *
- * @since 3.5.0
- */
- def mapValues[W](func: MapFunction[V, W], encoder: Encoder[W]): KeyValueGroupedDataset[K, W] = {
- mapValues(UdfUtils.mapFunctionToScalaFunc(func))(encoder)
- }
-
- /**
- * Returns a [[Dataset]] that contains each unique key. This is equivalent to doing mapping over
- * the Dataset to extract the keys and then running a distinct operation on those.
- *
- * @since 3.5.0
- */
- def keys: Dataset[K] = {
- throw new UnsupportedOperationException
- }
+ /** @inheritdoc */
+ def keys: Dataset[K] = unsupported()
- /**
- * (Scala-specific) Applies the given function to each group of data. For each unique group, the
- * function will be passed the group key and an iterator that contains all of the elements in
- * the group. The function can return an iterator containing elements of an arbitrary type which
- * will be returned as a new [[Dataset]].
- *
- * This function does not support partial aggregation, and as a result requires shuffling all
- * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
- * key, it is best to use the reduce function or an
- * `org.apache.spark.sql.expressions#Aggregator`.
- *
- * Internally, the implementation will spill to disk if any given group is too large to fit into
- * memory. However, users must take care to avoid materializing the whole iterator for a group
- * (for example, by calling `toList`) unless they are sure that this is possible given the
- * memory constraints of their cluster.
- *
- * @since 3.5.0
- */
- def flatMapGroups[U: Encoder](f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] = {
- flatMapSortedGroups()(f)
- }
-
- /**
- * (Java-specific) Applies the given function to each group of data. For each unique group, the
- * function will be passed the group key and an iterator that contains all of the elements in
- * the group. The function can return an iterator containing elements of an arbitrary type which
- * will be returned as a new [[Dataset]].
- *
- * This function does not support partial aggregation, and as a result requires shuffling all
- * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
- * key, it is best to use the reduce function or an
- * `org.apache.spark.sql.expressions#Aggregator`.
- *
- * Internally, the implementation will spill to disk if any given group is too large to fit into
- * memory. However, users must take care to avoid materializing the whole iterator for a group
- * (for example, by calling `toList`) unless they are sure that this is possible given the
- * memory constraints of their cluster.
- *
- * @since 3.5.0
- */
- def flatMapGroups[U](f: FlatMapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = {
- flatMapGroups(UdfUtils.flatMapGroupsFuncToScalaFunc(f))(encoder)
- }
-
- /**
- * (Scala-specific) Applies the given function to each group of data. For each unique group, the
- * function will be passed the group key and a sorted iterator that contains all of the elements
- * in the group. The function can return an iterator containing elements of an arbitrary type
- * which will be returned as a new [[Dataset]].
- *
- * This function does not support partial aggregation, and as a result requires shuffling all
- * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
- * key, it is best to use the reduce function or an
- * `org.apache.spark.sql.expressions#Aggregator`.
- *
- * Internally, the implementation will spill to disk if any given group is too large to fit into
- * memory. However, users must take care to avoid materializing the whole iterator for a group
- * (for example, by calling `toList`) unless they are sure that this is possible given the
- * memory constraints of their cluster.
- *
- * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except for the iterator to be
- * sorted according to the given sort expressions. That sorting does not add computational
- * complexity.
- *
- * @since 3.5.0
- */
+ /** @inheritdoc */
def flatMapSortedGroups[U: Encoder](sortExprs: Column*)(
- f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] = {
- throw new UnsupportedOperationException
- }
-
- /**
- * (Java-specific) Applies the given function to each group of data. For each unique group, the
- * function will be passed the group key and a sorted iterator that contains all of the elements
- * in the group. The function can return an iterator containing elements of an arbitrary type
- * which will be returned as a new [[Dataset]].
- *
- * This function does not support partial aggregation, and as a result requires shuffling all
- * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
- * key, it is best to use the reduce function or an
- * `org.apache.spark.sql.expressions#Aggregator`.
- *
- * Internally, the implementation will spill to disk if any given group is too large to fit into
- * memory. However, users must take care to avoid materializing the whole iterator for a group
- * (for example, by calling `toList`) unless they are sure that this is possible given the
- * memory constraints of their cluster.
- *
- * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except for the iterator to be
- * sorted according to the given sort expressions. That sorting does not add computational
- * complexity.
- *
- * @since 3.5.0
- */
- def flatMapSortedGroups[U](
- SortExprs: Array[Column],
- f: FlatMapGroupsFunction[K, V, U],
- encoder: Encoder[U]): Dataset[U] = {
- import org.apache.spark.util.ArrayImplicits._
- flatMapSortedGroups(SortExprs.toImmutableArraySeq: _*)(
- UdfUtils.flatMapGroupsFuncToScalaFunc(f))(encoder)
- }
-
- /**
- * (Scala-specific) Applies the given function to each group of data. For each unique group, the
- * function will be passed the group key and an iterator that contains all of the elements in
- * the group. The function can return an element of arbitrary type which will be returned as a
- * new [[Dataset]].
- *
- * This function does not support partial aggregation, and as a result requires shuffling all
- * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
- * key, it is best to use the reduce function or an
- * `org.apache.spark.sql.expressions#Aggregator`.
- *
- * Internally, the implementation will spill to disk if any given group is too large to fit into
- * memory. However, users must take care to avoid materializing the whole iterator for a group
- * (for example, by calling `toList`) unless they are sure that this is possible given the
- * memory constraints of their cluster.
- *
- * @since 3.5.0
- */
- def mapGroups[U: Encoder](f: (K, Iterator[V]) => U): Dataset[U] = {
- flatMapGroups(UdfUtils.mapGroupsFuncToFlatMapAdaptor(f))
- }
-
- /**
- * (Java-specific) Applies the given function to each group of data. For each unique group, the
- * function will be passed the group key and an iterator that contains all of the elements in
- * the group. The function can return an element of arbitrary type which will be returned as a
- * new [[Dataset]].
- *
- * This function does not support partial aggregation, and as a result requires shuffling all
- * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
- * key, it is best to use the reduce function or an
- * `org.apache.spark.sql.expressions#Aggregator`.
- *
- * Internally, the implementation will spill to disk if any given group is too large to fit into
- * memory. However, users must take care to avoid materializing the whole iterator for a group
- * (for example, by calling `toList`) unless they are sure that this is possible given the
- * memory constraints of their cluster.
- *
- * @since 3.5.0
- */
- def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = {
- mapGroups(UdfUtils.mapGroupsFuncToScalaFunc(f))(encoder)
- }
-
- /**
- * (Scala-specific) Reduces the elements of each group of data using the specified binary
- * function. The given function must be commutative and associative or the result may be
- * non-deterministic.
- *
- * @since 3.5.0
- */
- def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = {
- throw new UnsupportedOperationException
- }
-
- /**
- * (Java-specific) Reduces the elements of each group of data using the specified binary
- * function. The given function must be commutative and associative or the result may be
- * non-deterministic.
- *
- * @since 3.5.0
- */
- def reduceGroups(f: ReduceFunction[V]): Dataset[(K, V)] = {
- reduceGroups(UdfUtils.mapReduceFuncToScalaFunc(f))
- }
-
- /**
- * Internal helper function for building typed aggregations that return tuples. For simplicity
- * and code reuse, we do this without the help of the type system and then use helper functions
- * that cast appropriately for the user facing interface.
- */
- protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
- throw new UnsupportedOperationException
- }
-
- /**
- * Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key and the
- * result of computing this aggregation over all elements in the group.
- *
- * @since 3.5.0
- */
- def agg[U1](col1: TypedColumn[V, U1]): Dataset[(K, U1)] =
- aggUntyped(col1).asInstanceOf[Dataset[(K, U1)]]
-
- /**
- * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and
- * the result of computing these aggregations over all elements in the group.
- *
- * @since 3.5.0
- */
- def agg[U1, U2](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)] =
- aggUntyped(col1, col2).asInstanceOf[Dataset[(K, U1, U2)]]
+ f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] =
+ unsupported()
- /**
- * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and
- * the result of computing these aggregations over all elements in the group.
- *
- * @since 3.5.0
- */
- def agg[U1, U2, U3](
- col1: TypedColumn[V, U1],
- col2: TypedColumn[V, U2],
- col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)] =
- aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, U1, U2, U3)]]
+ /** @inheritdoc */
+ def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = unsupported()
- /**
- * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and
- * the result of computing these aggregations over all elements in the group.
- *
- * @since 3.5.0
- */
- def agg[U1, U2, U3, U4](
- col1: TypedColumn[V, U1],
- col2: TypedColumn[V, U2],
- col3: TypedColumn[V, U3],
- col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)] =
- aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, U1, U2, U3, U4)]]
-
- /**
- * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and
- * the result of computing these aggregations over all elements in the group.
- *
- * @since 3.5.0
- */
- def agg[U1, U2, U3, U4, U5](
- col1: TypedColumn[V, U1],
- col2: TypedColumn[V, U2],
- col3: TypedColumn[V, U3],
- col4: TypedColumn[V, U4],
- col5: TypedColumn[V, U5]): Dataset[(K, U1, U2, U3, U4, U5)] =
- aggUntyped(col1, col2, col3, col4, col5).asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5)]]
+ /** @inheritdoc */
+ protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = unsupported()
- /**
- * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and
- * the result of computing these aggregations over all elements in the group.
- *
- * @since 3.5.0
- */
- def agg[U1, U2, U3, U4, U5, U6](
- col1: TypedColumn[V, U1],
- col2: TypedColumn[V, U2],
- col3: TypedColumn[V, U3],
- col4: TypedColumn[V, U4],
- col5: TypedColumn[V, U5],
- col6: TypedColumn[V, U6]): Dataset[(K, U1, U2, U3, U4, U5, U6)] =
- aggUntyped(col1, col2, col3, col4, col5, col6)
- .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6)]]
-
- /**
- * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and
- * the result of computing these aggregations over all elements in the group.
- *
- * @since 3.5.0
- */
- def agg[U1, U2, U3, U4, U5, U6, U7](
- col1: TypedColumn[V, U1],
- col2: TypedColumn[V, U2],
- col3: TypedColumn[V, U3],
- col4: TypedColumn[V, U4],
- col5: TypedColumn[V, U5],
- col6: TypedColumn[V, U6],
- col7: TypedColumn[V, U7]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7)] =
- aggUntyped(col1, col2, col3, col4, col5, col6, col7)
- .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6, U7)]]
-
- /**
- * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and
- * the result of computing these aggregations over all elements in the group.
- *
- * @since 3.5.0
- */
- def agg[U1, U2, U3, U4, U5, U6, U7, U8](
- col1: TypedColumn[V, U1],
- col2: TypedColumn[V, U2],
- col3: TypedColumn[V, U3],
- col4: TypedColumn[V, U4],
- col5: TypedColumn[V, U5],
- col6: TypedColumn[V, U6],
- col7: TypedColumn[V, U7],
- col8: TypedColumn[V, U8]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7, U8)] =
- aggUntyped(col1, col2, col3, col4, col5, col6, col7, col8)
- .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6, U7, U8)]]
-
- /**
- * Returns a [[Dataset]] that contains a tuple with each key and the number of items present for
- * that key.
- *
- * @since 3.5.0
- */
- def count(): Dataset[(K, Long)] = agg(functions.count("*"))
-
- /**
- * (Scala-specific) Applies the given function to each cogrouped data. For each unique group,
- * the function will be passed the grouping key and 2 iterators containing all elements in the
- * group from [[Dataset]] `this` and `other`. The function can return an iterator containing
- * elements of an arbitrary type which will be returned as a new [[Dataset]].
- *
- * @since 3.5.0
- */
- def cogroup[U, R: Encoder](other: KeyValueGroupedDataset[K, U])(
- f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = {
- cogroupSorted(other)()()(f)
- }
-
- /**
- * (Java-specific) Applies the given function to each cogrouped data. For each unique group, the
- * function will be passed the grouping key and 2 iterators containing all elements in the group
- * from [[Dataset]] `this` and `other`. The function can return an iterator containing elements
- * of an arbitrary type which will be returned as a new [[Dataset]].
- *
- * @since 3.5.0
- */
- def cogroup[U, R](
- other: KeyValueGroupedDataset[K, U],
- f: CoGroupFunction[K, V, U, R],
- encoder: Encoder[R]): Dataset[R] = {
- cogroup(other)(UdfUtils.coGroupFunctionToScalaFunc(f))(encoder)
- }
-
- /**
- * (Scala-specific) Applies the given function to each sorted cogrouped data. For each unique
- * group, the function will be passed the grouping key and 2 sorted iterators containing all
- * elements in the group from [[Dataset]] `this` and `other`. The function can return an
- * iterator containing elements of an arbitrary type which will be returned as a new
- * [[Dataset]].
- *
- * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except for the iterators to be
- * sorted according to the given sort expressions. That sorting does not add computational
- * complexity.
- *
- * @since 3.5.0
- */
+ /** @inheritdoc */
def cogroupSorted[U, R: Encoder](other: KeyValueGroupedDataset[K, U])(thisSortExprs: Column*)(
- otherSortExprs: Column*)(
- f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = {
- throw new UnsupportedOperationException
- }
-
- /**
- * (Java-specific) Applies the given function to each sorted cogrouped data. For each unique
- * group, the function will be passed the grouping key and 2 sorted iterators containing all
- * elements in the group from [[Dataset]] `this` and `other`. The function can return an
- * iterator containing elements of an arbitrary type which will be returned as a new
- * [[Dataset]].
- *
- * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except for the iterators to be
- * sorted according to the given sort expressions. That sorting does not add computational
- * complexity.
- *
- * @since 3.5.0
- */
- def cogroupSorted[U, R](
- other: KeyValueGroupedDataset[K, U],
- thisSortExprs: Array[Column],
- otherSortExprs: Array[Column],
- f: CoGroupFunction[K, V, U, R],
- encoder: Encoder[R]): Dataset[R] = {
- import org.apache.spark.util.ArrayImplicits._
- cogroupSorted(other)(thisSortExprs.toImmutableArraySeq: _*)(
- otherSortExprs.toImmutableArraySeq: _*)(UdfUtils.coGroupFunctionToScalaFunc(f))(encoder)
- }
+ otherSortExprs: Column*)(f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] =
+ unsupported()
protected[sql] def flatMapGroupsWithStateHelper[S: Encoder, U: Encoder](
outputMode: Option[OutputMode],
timeoutConf: GroupStateTimeout,
initialState: Option[KeyValueGroupedDataset[K, S]],
isMapGroupWithState: Boolean)(
- func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] = {
- throw new UnsupportedOperationException
- }
+ func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] = unsupported()
- /**
- * (Scala-specific) Applies the given function to each group of data, while maintaining a
- * user-defined per-group state. The result Dataset will represent the objects returned by the
- * function. For a static batch Dataset, the function will be invoked once per group. For a
- * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
- * and updates to each group's state will be saved across invocations. See
- * [[org.apache.spark.sql.streaming.GroupState]] for more details.
- *
- * @tparam S
- * The type of the user-defined state. Must be encodable to Spark SQL types.
- * @tparam U
- * The type of the output objects. Must be encodable to Spark SQL types.
- * @param func
- * Function to be called on every group.
- *
- * See [[Encoder]] for more details on what types are encodable to Spark SQL.
- * @since 3.5.0
- */
+ /** @inheritdoc */
def mapGroupsWithState[S: Encoder, U: Encoder](
func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = {
mapGroupsWithState(GroupStateTimeout.NoTimeout)(func)
}
- /**
- * (Scala-specific) Applies the given function to each group of data, while maintaining a
- * user-defined per-group state. The result Dataset will represent the objects returned by the
- * function. For a static batch Dataset, the function will be invoked once per group. For a
- * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
- * and updates to each group's state will be saved across invocations. See
- * [[org.apache.spark.sql.streaming.GroupState]] for more details.
- *
- * @tparam S
- * The type of the user-defined state. Must be encodable to Spark SQL types.
- * @tparam U
- * The type of the output objects. Must be encodable to Spark SQL types.
- * @param func
- * Function to be called on every group.
- * @param timeoutConf
- * Timeout configuration for groups that do not receive data for a while.
- *
- * See [[Encoder]] for more details on what types are encodable to Spark SQL.
- * @since 3.5.0
- */
+ /** @inheritdoc */
def mapGroupsWithState[S: Encoder, U: Encoder](timeoutConf: GroupStateTimeout)(
func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = {
flatMapGroupsWithStateHelper(None, timeoutConf, None, isMapGroupWithState = true)(
- UdfUtils.mapGroupsWithStateFuncToFlatMapAdaptor(func))
+ UDFAdaptors.mapGroupsWithStateToFlatMapWithState(func))
}
- /**
- * (Scala-specific) Applies the given function to each group of data, while maintaining a
- * user-defined per-group state. The result Dataset will represent the objects returned by the
- * function. For a static batch Dataset, the function will be invoked once per group. For a
- * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
- * and updates to each group's state will be saved across invocations. See
- * [[org.apache.spark.sql.streaming.GroupState]] for more details.
- *
- * @tparam S
- * The type of the user-defined state. Must be encodable to Spark SQL types.
- * @tparam U
- * The type of the output objects. Must be encodable to Spark SQL types.
- * @param func
- * Function to be called on every group.
- * @param timeoutConf
- * Timeout Conf, see GroupStateTimeout for more details
- * @param initialState
- * The user provided state that will be initialized when the first batch of data is processed
- * in the streaming query. The user defined function will be called on the state data even if
- * there are no other values in the group. To convert a Dataset ds of type Dataset[(K, S)] to
- * a KeyValueGroupedDataset[K, S] do {{{ds.groupByKey(x => x._1).mapValues(_._2)}}}
- *
- * See [[Encoder]] for more details on what types are encodable to Spark SQL.
- * @since 3.5.0
- */
+ /** @inheritdoc */
def mapGroupsWithState[S: Encoder, U: Encoder](
timeoutConf: GroupStateTimeout,
initialState: KeyValueGroupedDataset[K, S])(
@@ -549,134 +107,10 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable {
None,
timeoutConf,
Some(initialState),
- isMapGroupWithState = true)(UdfUtils.mapGroupsWithStateFuncToFlatMapAdaptor(func))
+ isMapGroupWithState = true)(UDFAdaptors.mapGroupsWithStateToFlatMapWithState(func))
}
- /**
- * (Java-specific) Applies the given function to each group of data, while maintaining a
- * user-defined per-group state. The result Dataset will represent the objects returned by the
- * function. For a static batch Dataset, the function will be invoked once per group. For a
- * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
- * and updates to each group's state will be saved across invocations. See `GroupState` for more
- * details.
- *
- * @tparam S
- * The type of the user-defined state. Must be encodable to Spark SQL types.
- * @tparam U
- * The type of the output objects. Must be encodable to Spark SQL types.
- * @param func
- * Function to be called on every group.
- * @param stateEncoder
- * Encoder for the state type.
- * @param outputEncoder
- * Encoder for the output type.
- *
- * See [[Encoder]] for more details on what types are encodable to Spark SQL.
- * @since 3.5.0
- */
- def mapGroupsWithState[S, U](
- func: MapGroupsWithStateFunction[K, V, S, U],
- stateEncoder: Encoder[S],
- outputEncoder: Encoder[U]): Dataset[U] = {
- mapGroupsWithState[S, U](UdfUtils.mapGroupsWithStateFuncToScalaFunc(func))(
- stateEncoder,
- outputEncoder)
- }
-
- /**
- * (Java-specific) Applies the given function to each group of data, while maintaining a
- * user-defined per-group state. The result Dataset will represent the objects returned by the
- * function. For a static batch Dataset, the function will be invoked once per group. For a
- * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
- * and updates to each group's state will be saved across invocations. See `GroupState` for more
- * details.
- *
- * @tparam S
- * The type of the user-defined state. Must be encodable to Spark SQL types.
- * @tparam U
- * The type of the output objects. Must be encodable to Spark SQL types.
- * @param func
- * Function to be called on every group.
- * @param stateEncoder
- * Encoder for the state type.
- * @param outputEncoder
- * Encoder for the output type.
- * @param timeoutConf
- * Timeout configuration for groups that do not receive data for a while.
- *
- * See [[Encoder]] for more details on what types are encodable to Spark SQL.
- * @since 3.5.0
- */
- def mapGroupsWithState[S, U](
- func: MapGroupsWithStateFunction[K, V, S, U],
- stateEncoder: Encoder[S],
- outputEncoder: Encoder[U],
- timeoutConf: GroupStateTimeout): Dataset[U] = {
- mapGroupsWithState[S, U](timeoutConf)(UdfUtils.mapGroupsWithStateFuncToScalaFunc(func))(
- stateEncoder,
- outputEncoder)
- }
-
- /**
- * (Java-specific) Applies the given function to each group of data, while maintaining a
- * user-defined per-group state. The result Dataset will represent the objects returned by the
- * function. For a static batch Dataset, the function will be invoked once per group. For a
- * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
- * and updates to each group's state will be saved across invocations. See `GroupState` for more
- * details.
- *
- * @tparam S
- * The type of the user-defined state. Must be encodable to Spark SQL types.
- * @tparam U
- * The type of the output objects. Must be encodable to Spark SQL types.
- * @param func
- * Function to be called on every group.
- * @param stateEncoder
- * Encoder for the state type.
- * @param outputEncoder
- * Encoder for the output type.
- * @param timeoutConf
- * Timeout configuration for groups that do not receive data for a while.
- * @param initialState
- * The user provided state that will be initialized when the first batch of data is processed
- * in the streaming query. The user defined function will be called on the state data even if
- * there are no other values in the group.
- *
- * See [[Encoder]] for more details on what types are encodable to Spark SQL.
- * @since 3.5.0
- */
- def mapGroupsWithState[S, U](
- func: MapGroupsWithStateFunction[K, V, S, U],
- stateEncoder: Encoder[S],
- outputEncoder: Encoder[U],
- timeoutConf: GroupStateTimeout,
- initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = {
- mapGroupsWithState[S, U](timeoutConf, initialState)(
- UdfUtils.mapGroupsWithStateFuncToScalaFunc(func))(stateEncoder, outputEncoder)
- }
-
- /**
- * (Scala-specific) Applies the given function to each group of data, while maintaining a
- * user-defined per-group state. The result Dataset will represent the objects returned by the
- * function. For a static batch Dataset, the function will be invoked once per group. For a
- * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
- * and updates to each group's state will be saved across invocations. See `GroupState` for more
- * details.
- *
- * @tparam S
- * The type of the user-defined state. Must be encodable to Spark SQL types.
- * @tparam U
- * The type of the output objects. Must be encodable to Spark SQL types.
- * @param func
- * Function to be called on every group.
- * @param outputMode
- * The output mode of the function.
- * @param timeoutConf
- * Timeout configuration for groups that do not receive data for a while.
- *
- * See [[Encoder]] for more details on what types are encodable to Spark SQL.
- * @since 3.5.0
- */
+ /** @inheritdoc */
def flatMapGroupsWithState[S: Encoder, U: Encoder](
outputMode: OutputMode,
timeoutConf: GroupStateTimeout)(
@@ -688,33 +122,7 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable {
isMapGroupWithState = false)(func)
}
- /**
- * (Scala-specific) Applies the given function to each group of data, while maintaining a
- * user-defined per-group state. The result Dataset will represent the objects returned by the
- * function. For a static batch Dataset, the function will be invoked once per group. For a
- * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
- * and updates to each group's state will be saved across invocations. See `GroupState` for more
- * details.
- *
- * @tparam S
- * The type of the user-defined state. Must be encodable to Spark SQL types.
- * @tparam U
- * The type of the output objects. Must be encodable to Spark SQL types.
- * @param func
- * Function to be called on every group.
- * @param outputMode
- * The output mode of the function.
- * @param timeoutConf
- * Timeout configuration for groups that do not receive data for a while.
- * @param initialState
- * The user provided state that will be initialized when the first batch of data is processed
- * in the streaming query. The user defined function will be called on the state data even if
- * there are no other values in the group. To convert a Dataset `ds` of type
- * `Dataset[(K, S)]` to a `KeyValueGroupedDataset[K, S]`, use
- * {{{ds.groupByKey(x => x._1).mapValues(_._2)}}} See [[Encoder]] for more details on what
- * types are encodable to Spark SQL.
- * @since 3.5.0
- */
+ /** @inheritdoc */
def flatMapGroupsWithState[S: Encoder, U: Encoder](
outputMode: OutputMode,
timeoutConf: GroupStateTimeout,
@@ -727,201 +135,244 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable {
isMapGroupWithState = false)(func)
}
- /**
- * (Java-specific) Applies the given function to each group of data, while maintaining a
- * user-defined per-group state. The result Dataset will represent the objects returned by the
- * function. For a static batch Dataset, the function will be invoked once per group. For a
- * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
- * and updates to each group's state will be saved across invocations. See `GroupState` for more
- * details.
- *
- * @tparam S
- * The type of the user-defined state. Must be encodable to Spark SQL types.
- * @tparam U
- * The type of the output objects. Must be encodable to Spark SQL types.
- * @param func
- * Function to be called on every group.
- * @param outputMode
- * The output mode of the function.
- * @param stateEncoder
- * Encoder for the state type.
- * @param outputEncoder
- * Encoder for the output type.
- * @param timeoutConf
- * Timeout configuration for groups that do not receive data for a while.
- *
- * See [[Encoder]] for more details on what types are encodable to Spark SQL.
- * @since 3.5.0
- */
- def flatMapGroupsWithState[S, U](
+ /** @inheritdoc */
+ private[sql] def transformWithState[U: Encoder](
+ statefulProcessor: StatefulProcessor[K, V, U],
+ timeMode: TimeMode,
+ outputMode: OutputMode): Dataset[U] =
+ unsupported()
+
+ /** @inheritdoc */
+ private[sql] def transformWithState[U: Encoder, S: Encoder](
+ statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
+ timeMode: TimeMode,
+ outputMode: OutputMode,
+ initialState: KeyValueGroupedDataset[K, S]): Dataset[U] =
+ unsupported()
+
+ /** @inheritdoc */
+ override private[sql] def transformWithState[U: Encoder](
+ statefulProcessor: StatefulProcessor[K, V, U],
+ eventTimeColumnName: String,
+ outputMode: OutputMode): Dataset[U] = unsupported()
+
+ /** @inheritdoc */
+ override private[sql] def transformWithState[U: Encoder, S: Encoder](
+ statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
+ eventTimeColumnName: String,
+ outputMode: OutputMode,
+ initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = unsupported()
+
+ // Overrides...
+ /** @inheritdoc */
+ override def mapValues[W](
+ func: MapFunction[V, W],
+ encoder: Encoder[W]): KeyValueGroupedDataset[K, W] = super.mapValues(func, encoder)
+
+ /** @inheritdoc */
+ override def flatMapGroups[U: Encoder](f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] =
+ super.flatMapGroups(f)
+
+ /** @inheritdoc */
+ override def flatMapGroups[U](
+ f: FlatMapGroupsFunction[K, V, U],
+ encoder: Encoder[U]): Dataset[U] = super.flatMapGroups(f, encoder)
+
+ /** @inheritdoc */
+ override def flatMapSortedGroups[U](
+ SortExprs: Array[Column],
+ f: FlatMapGroupsFunction[K, V, U],
+ encoder: Encoder[U]): Dataset[U] = super.flatMapSortedGroups(SortExprs, f, encoder)
+
+ /** @inheritdoc */
+ override def mapGroups[U: Encoder](f: (K, Iterator[V]) => U): Dataset[U] = super.mapGroups(f)
+
+ /** @inheritdoc */
+ override def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] =
+ super.mapGroups(f, encoder)
+
+ /** @inheritdoc */
+ override def mapGroupsWithState[S, U](
+ func: MapGroupsWithStateFunction[K, V, S, U],
+ stateEncoder: Encoder[S],
+ outputEncoder: Encoder[U]): Dataset[U] =
+ super.mapGroupsWithState(func, stateEncoder, outputEncoder)
+
+ /** @inheritdoc */
+ override def mapGroupsWithState[S, U](
+ func: MapGroupsWithStateFunction[K, V, S, U],
+ stateEncoder: Encoder[S],
+ outputEncoder: Encoder[U],
+ timeoutConf: GroupStateTimeout): Dataset[U] =
+ super.mapGroupsWithState(func, stateEncoder, outputEncoder, timeoutConf)
+
+ /** @inheritdoc */
+ override def mapGroupsWithState[S, U](
+ func: MapGroupsWithStateFunction[K, V, S, U],
+ stateEncoder: Encoder[S],
+ outputEncoder: Encoder[U],
+ timeoutConf: GroupStateTimeout,
+ initialState: KeyValueGroupedDataset[K, S]): Dataset[U] =
+ super.mapGroupsWithState(func, stateEncoder, outputEncoder, timeoutConf, initialState)
+
+ /** @inheritdoc */
+ override def flatMapGroupsWithState[S, U](
func: FlatMapGroupsWithStateFunction[K, V, S, U],
outputMode: OutputMode,
stateEncoder: Encoder[S],
outputEncoder: Encoder[U],
- timeoutConf: GroupStateTimeout): Dataset[U] = {
- val f = UdfUtils.flatMapGroupsWithStateFuncToScalaFunc(func)
- flatMapGroupsWithState[S, U](outputMode, timeoutConf)(f)(stateEncoder, outputEncoder)
- }
+ timeoutConf: GroupStateTimeout): Dataset[U] =
+ super.flatMapGroupsWithState(func, outputMode, stateEncoder, outputEncoder, timeoutConf)
- /**
- * (Java-specific) Applies the given function to each group of data, while maintaining a
- * user-defined per-group state. The result Dataset will represent the objects returned by the
- * function. For a static batch Dataset, the function will be invoked once per group. For a
- * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
- * and updates to each group's state will be saved across invocations. See `GroupState` for more
- * details.
- *
- * @tparam S
- * The type of the user-defined state. Must be encodable to Spark SQL types.
- * @tparam U
- * The type of the output objects. Must be encodable to Spark SQL types.
- * @param func
- * Function to be called on every group.
- * @param outputMode
- * The output mode of the function.
- * @param stateEncoder
- * Encoder for the state type.
- * @param outputEncoder
- * Encoder for the output type.
- * @param timeoutConf
- * Timeout configuration for groups that do not receive data for a while.
- * @param initialState
- * The user provided state that will be initialized when the first batch of data is processed
- * in the streaming query. The user defined function will be called on the state data even if
- * there are no other values in the group. To convert a Dataset `ds` of type
- * `Dataset[(K, S)]` to a `KeyValueGroupedDataset[K, S]`, use
- * {{{ds.groupByKey(x => x._1).mapValues(_._2)}}}
- *
- * See [[Encoder]] for more details on what types are encodable to Spark SQL.
- * @since 3.5.0
- */
- def flatMapGroupsWithState[S, U](
+ /** @inheritdoc */
+ override def flatMapGroupsWithState[S, U](
func: FlatMapGroupsWithStateFunction[K, V, S, U],
outputMode: OutputMode,
stateEncoder: Encoder[S],
outputEncoder: Encoder[U],
timeoutConf: GroupStateTimeout,
- initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = {
- val f = UdfUtils.flatMapGroupsWithStateFuncToScalaFunc(func)
- flatMapGroupsWithState[S, U](outputMode, timeoutConf, initialState)(f)(
- stateEncoder,
- outputEncoder)
- }
-
- /**
- * (Scala-specific) Invokes methods defined in the stateful processor used in arbitrary state
- * API v2. We allow the user to act on per-group set of input rows along with keyed state and
- * the user can choose to output/return 0 or more rows. For a streaming dataframe, we will
- * repeatedly invoke the interface methods for new rows in each trigger and the user's
- * state/state variables will be stored persistently across invocations. Currently this operator
- * is not supported with Spark Connect.
- *
- * @tparam U
- * The type of the output objects. Must be encodable to Spark SQL types.
- * @param statefulProcessor
- * Instance of statefulProcessor whose functions will be invoked by the operator.
- * @param timeMode
- * The time mode semantics of the stateful processor for timers and TTL.
- * @param outputMode
- * The output mode of the stateful processor.
- */
- private[sql] def transformWithState[U: Encoder](
+ initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = super.flatMapGroupsWithState(
+ func,
+ outputMode,
+ stateEncoder,
+ outputEncoder,
+ timeoutConf,
+ initialState)
+
+ /** @inheritdoc */
+ override private[sql] def transformWithState[U: Encoder](
statefulProcessor: StatefulProcessor[K, V, U],
timeMode: TimeMode,
- outputMode: OutputMode): Dataset[U] = {
- throw new UnsupportedOperationException
- }
+ outputMode: OutputMode,
+ outputEncoder: Encoder[U]) =
+ super.transformWithState(statefulProcessor, timeMode, outputMode, outputEncoder)
- /**
- * (Java-specific) Invokes methods defined in the stateful processor used in arbitrary state API
- * v2. We allow the user to act on per-group set of input rows along with keyed state and the
- * user can choose to output/return 0 or more rows. For a streaming dataframe, we will
- * repeatedly invoke the interface methods for new rows in each trigger and the user's
- * state/state variables will be stored persistently across invocations. Currently this operator
- * is not supported with Spark Connect.
- *
- * @tparam U
- * The type of the output objects. Must be encodable to Spark SQL types.
- * @param statefulProcessor
- * Instance of statefulProcessor whose functions will be invoked by the operator.
- * @param timeMode
- * The time mode semantics of the stateful processor for timers and TTL.
- * @param outputMode
- * The output mode of the stateful processor.
- * @param outputEncoder
- * Encoder for the output type.
- */
- private[sql] def transformWithState[U: Encoder](
+ /** @inheritdoc */
+ override private[sql] def transformWithState[U: Encoder](
statefulProcessor: StatefulProcessor[K, V, U],
- timeMode: TimeMode,
+ eventTimeColumnName: String,
outputMode: OutputMode,
- outputEncoder: Encoder[U]): Dataset[U] = {
- throw new UnsupportedOperationException
- }
+ outputEncoder: Encoder[U]) =
+ super.transformWithState(statefulProcessor, eventTimeColumnName, outputMode, outputEncoder)
- /**
- * (Scala-specific) Invokes methods defined in the stateful processor used in arbitrary state
- * API v2. Functions as the function above, but with additional initial state. Currently this
- * operator is not supported with Spark Connect.
- *
- * @tparam U
- * The type of the output objects. Must be encodable to Spark SQL types.
- * @tparam S
- * The type of initial state objects. Must be encodable to Spark SQL types.
- * @param statefulProcessor
- * Instance of statefulProcessor whose functions will be invoked by the operator.
- * @param timeMode
- * The time mode semantics of the stateful processor for timers and TTL.
- * @param outputMode
- * The output mode of the stateful processor.
- * @param initialState
- * User provided initial state that will be used to initiate state for the query in the first
- * batch.
- *
- * See [[Encoder]] for more details on what types are encodable to Spark SQL.
- */
- private[sql] def transformWithState[U: Encoder, S: Encoder](
+ /** @inheritdoc */
+ override private[sql] def transformWithState[U: Encoder, S: Encoder](
statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
timeMode: TimeMode,
outputMode: OutputMode,
- initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = {
- throw new UnsupportedOperationException
- }
-
- /**
- * (Java-specific) Invokes methods defined in the stateful processor used in arbitrary state API
- * v2. Functions as the function above, but with additional initial state. Currently this
- * operator is not supported with Spark Connect.
- *
- * @tparam U
- * The type of the output objects. Must be encodable to Spark SQL types.
- * @tparam S
- * The type of initial state objects. Must be encodable to Spark SQL types.
- * @param statefulProcessor
- * Instance of statefulProcessor whose functions will be invoked by the operator.
- * @param timeMode
- * The time mode semantics of the stateful processor for timers and TTL.
- * @param outputMode
- * The output mode of the stateful processor.
- * @param initialState
- * User provided initial state that will be used to initiate state for the query in the first
- * batch.
- * @param outputEncoder
- * Encoder for the output type.
- * @param initialStateEncoder
- * Encoder for the initial state type.
- *
- * See [[Encoder]] for more details on what types are encodable to Spark SQL.
- */
- private[sql] def transformWithState[U: Encoder, S: Encoder](
+ initialState: KeyValueGroupedDataset[K, S],
+ outputEncoder: Encoder[U],
+ initialStateEncoder: Encoder[S]) = super.transformWithState(
+ statefulProcessor,
+ timeMode,
+ outputMode,
+ initialState,
+ outputEncoder,
+ initialStateEncoder)
+
+ /** @inheritdoc */
+ override private[sql] def transformWithState[U: Encoder, S: Encoder](
statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
- timeMode: TimeMode,
outputMode: OutputMode,
initialState: KeyValueGroupedDataset[K, S],
+ eventTimeColumnName: String,
outputEncoder: Encoder[U],
- initialStateEncoder: Encoder[S]): Dataset[U] = {
- throw new UnsupportedOperationException
- }
+ initialStateEncoder: Encoder[S]) = super.transformWithState(
+ statefulProcessor,
+ outputMode,
+ initialState,
+ eventTimeColumnName,
+ outputEncoder,
+ initialStateEncoder)
+
+ /** @inheritdoc */
+ override def reduceGroups(f: ReduceFunction[V]): Dataset[(K, V)] = super.reduceGroups(f)
+
+ /** @inheritdoc */
+ override def agg[U1](col1: TypedColumn[V, U1]): Dataset[(K, U1)] = super.agg(col1)
+
+ /** @inheritdoc */
+ override def agg[U1, U2](
+ col1: TypedColumn[V, U1],
+ col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)] = super.agg(col1, col2)
+
+ /** @inheritdoc */
+ override def agg[U1, U2, U3](
+ col1: TypedColumn[V, U1],
+ col2: TypedColumn[V, U2],
+ col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)] = super.agg(col1, col2, col3)
+
+ /** @inheritdoc */
+ override def agg[U1, U2, U3, U4](
+ col1: TypedColumn[V, U1],
+ col2: TypedColumn[V, U2],
+ col3: TypedColumn[V, U3],
+ col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)] = super.agg(col1, col2, col3, col4)
+
+ /** @inheritdoc */
+ override def agg[U1, U2, U3, U4, U5](
+ col1: TypedColumn[V, U1],
+ col2: TypedColumn[V, U2],
+ col3: TypedColumn[V, U3],
+ col4: TypedColumn[V, U4],
+ col5: TypedColumn[V, U5]): Dataset[(K, U1, U2, U3, U4, U5)] =
+ super.agg(col1, col2, col3, col4, col5)
+
+ /** @inheritdoc */
+ override def agg[U1, U2, U3, U4, U5, U6](
+ col1: TypedColumn[V, U1],
+ col2: TypedColumn[V, U2],
+ col3: TypedColumn[V, U3],
+ col4: TypedColumn[V, U4],
+ col5: TypedColumn[V, U5],
+ col6: TypedColumn[V, U6]): Dataset[(K, U1, U2, U3, U4, U5, U6)] =
+ super.agg(col1, col2, col3, col4, col5, col6)
+
+ /** @inheritdoc */
+ override def agg[U1, U2, U3, U4, U5, U6, U7](
+ col1: TypedColumn[V, U1],
+ col2: TypedColumn[V, U2],
+ col3: TypedColumn[V, U3],
+ col4: TypedColumn[V, U4],
+ col5: TypedColumn[V, U5],
+ col6: TypedColumn[V, U6],
+ col7: TypedColumn[V, U7]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7)] =
+ super.agg(col1, col2, col3, col4, col5, col6, col7)
+
+ /** @inheritdoc */
+ override def agg[U1, U2, U3, U4, U5, U6, U7, U8](
+ col1: TypedColumn[V, U1],
+ col2: TypedColumn[V, U2],
+ col3: TypedColumn[V, U3],
+ col4: TypedColumn[V, U4],
+ col5: TypedColumn[V, U5],
+ col6: TypedColumn[V, U6],
+ col7: TypedColumn[V, U7],
+ col8: TypedColumn[V, U8]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7, U8)] =
+ super.agg(col1, col2, col3, col4, col5, col6, col7, col8)
+
+ /** @inheritdoc */
+ override def count(): Dataset[(K, Long)] = super.count()
+
+ /** @inheritdoc */
+ override def cogroup[U, R: Encoder](other: KeyValueGroupedDataset[K, U])(
+ f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] =
+ super.cogroup(other)(f)
+
+ /** @inheritdoc */
+ override def cogroup[U, R](
+ other: KeyValueGroupedDataset[K, U],
+ f: CoGroupFunction[K, V, U, R],
+ encoder: Encoder[R]): Dataset[R] = super.cogroup(other, f, encoder)
+
+ /** @inheritdoc */
+ override def cogroupSorted[U, R](
+ other: KeyValueGroupedDataset[K, U],
+ thisSortExprs: Array[Column],
+ otherSortExprs: Array[Column],
+ f: CoGroupFunction[K, V, U, R],
+ encoder: Encoder[R]): Dataset[R] =
+ super.cogroupSorted(other, thisSortExprs, otherSortExprs, f, encoder)
}
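A sketch of the stateful API on this class: a per-key running count using mapGroupsWithState (assumes `events: Dataset[(String, Long)]` and `import spark.implicits._`):

    import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}

    val running = events
      .groupByKey(_._1)
      .mapGroupsWithState(GroupStateTimeout.NoTimeout) {
        (key: String, values: Iterator[(String, Long)], state: GroupState[Long]) =>
          val total = state.getOption.getOrElse(0L) + values.size
          state.update(total)
          key -> total
      }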
/**
@@ -934,12 +385,11 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable {
private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
private val sparkSession: SparkSession,
private val plan: proto.Plan,
- private val ikEncoder: AgnosticEncoder[IK],
private val kEncoder: AgnosticEncoder[K],
private val ivEncoder: AgnosticEncoder[IV],
private val vEncoder: AgnosticEncoder[V],
private val groupingExprs: java.util.List[proto.Expression],
- private val valueMapFunc: IV => V,
+ private val valueMapFunc: Option[IV => V],
private val keysFunc: () => Dataset[IK])
extends KeyValueGroupedDataset[K, V] {
import sparkSession.RichColumn
@@ -948,7 +398,6 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
new KeyValueGroupedDatasetImpl[L, V, IK, IV](
sparkSession,
plan,
- ikEncoder,
encoderFor[L],
ivEncoder,
vEncoder,
@@ -961,12 +410,13 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
new KeyValueGroupedDatasetImpl[K, W, IK, IV](
sparkSession,
plan,
- ikEncoder,
kEncoder,
ivEncoder,
encoderFor[W],
groupingExprs,
- valueMapFunc.andThen(valueFunc),
+ valueMapFunc
+ .map(_.andThen(valueFunc))
+ .orElse(Option(valueFunc.asInstanceOf[IV => W])),
keysFunc)
}
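Making `valueMapFunc` an Option removes the old identity-function sentinel; the composition above behaves like this standalone sketch:

    // None means "values are not remapped yet"; Some(f) chains the new mapping after f.
    def composeValueMap[IV, V, W](current: Option[IV => V], next: V => W): Option[IV => W] =
      current match {
        case Some(f) => Some(f.andThen(next))
        case None    => Some(next.asInstanceOf[IV => W]) // first mapping: IV =:= V here
      }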
@@ -979,8 +429,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
override def flatMapSortedGroups[U: Encoder](sortExprs: Column*)(
f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] = {
// Apply mapValues changes to the udf
- val nf =
- if (valueMapFunc == UdfUtils.identical()) f else UdfUtils.mapValuesAdaptor(f, valueMapFunc)
+ val nf = UDFAdaptors.flatMapGroupsWithMappedValues(f, valueMapFunc)
val outputEncoder = encoderFor[U]
sparkSession.newDataset[U](outputEncoder) { builder =>
builder.getGroupMapBuilder
@@ -994,10 +443,9 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
override def cogroupSorted[U, R: Encoder](other: KeyValueGroupedDataset[K, U])(
thisSortExprs: Column*)(otherSortExprs: Column*)(
f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = {
- assert(other.isInstanceOf[KeyValueGroupedDatasetImpl[K, U, _, _]])
- val otherImpl = other.asInstanceOf[KeyValueGroupedDatasetImpl[K, U, _, _]]
+ val otherImpl = other.asInstanceOf[KeyValueGroupedDatasetImpl[K, U, _, Any]]
// Apply mapValues changes to the udf
- val nf = UdfUtils.mapValuesAdaptor(f, valueMapFunc, otherImpl.valueMapFunc)
+ val nf = UDFAdaptors.coGroupWithMappedValues(f, valueMapFunc, otherImpl.valueMapFunc)
val outputEncoder = encoderFor[R]
sparkSession.newDataset[R](outputEncoder) { builder =>
builder.getCoGroupMapBuilder
@@ -1012,8 +460,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
}
override protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
- // TODO(SPARK-43415): For each column, apply the valueMap func first
- // apply keyAs change
+ // TODO(SPARK-43415): For each column, apply the valueMap func first...
val rEnc = ProductEncoder.tuple(kEncoder +: columns.map(c => encoderFor(c.encoder)))
sparkSession.newDataset(rEnc) { builder =>
builder.getAggregateBuilder
@@ -1047,22 +494,15 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
throw new IllegalArgumentException("The output mode of function should be append or update")
}
- if (initialState.isDefined) {
- assert(initialState.get.isInstanceOf[KeyValueGroupedDatasetImpl[K, S, _, _]])
- }
-
val initialStateImpl = if (initialState.isDefined) {
+ assert(initialState.get.isInstanceOf[KeyValueGroupedDatasetImpl[K, S, _, _]])
initialState.get.asInstanceOf[KeyValueGroupedDatasetImpl[K, S, _, _]]
} else {
null
}
val outputEncoder = encoderFor[U]
- val nf = if (valueMapFunc == UdfUtils.identical()) {
- func
- } else {
- UdfUtils.mapValuesAdaptor(func, valueMapFunc)
- }
+ val nf = UDFAdaptors.flatMapGroupsWithStateWithMappedValues(func, valueMapFunc)
sparkSession.newDataset[U](outputEncoder) { builder =>
val groupMapBuilder = builder.getGroupMapBuilder
@@ -1097,6 +537,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
* We cannot deserialize a connect [[KeyValueGroupedDataset]] because of a class clash on the
* server side. We null out the instance for now.
*/
+ @unused("this is used by java serialization")
private def writeReplace(): Any = null
}
@@ -1114,11 +555,10 @@ private object KeyValueGroupedDatasetImpl {
session,
ds.plan,
kEncoder,
- kEncoder,
ds.agnosticEncoder,
ds.agnosticEncoder,
Arrays.asList(toExpr(gf.apply(col("*")))),
- UdfUtils.identical(),
+ None,
() => ds.map(groupingFunc)(kEncoder))
}
@@ -1137,11 +577,10 @@ private object KeyValueGroupedDatasetImpl {
session,
df.plan,
kEncoder,
- kEncoder,
vEncoder,
vEncoder,
(Seq(dummyGroupingFunc) ++ groupingExprs).map(toExpr).asJava,
- UdfUtils.identical(),
+ None,
() => df.select(groupingExprs: _*).as(kEncoder))
}
}
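The `UDFAdaptors` helpers referenced throughout this file are internal; conceptually they fold the optional value mapping into the user function before it is shipped to the server. A rough sketch of the idea (names and shape are assumptions, not the actual implementation):

    def withMappedValues[K, IV, V, U](
        f: (K, Iterator[V]) => IterableOnce[U],
        valueMap: Option[IV => V]): (K, Iterator[IV]) => IterableOnce[U] =
      valueMap match {
        case Some(m) => (key, values) => f(key, values.map(m))
        case None    => f.asInstanceOf[(K, Iterator[IV]) => IterableOnce[U]] // IV =:= V
      }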
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index c9b011ca4535b..14ceb3f4bb144 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
import scala.jdk.CollectionConverters._
import org.apache.spark.connect.proto
+import org.apache.spark.sql.connect.ConnectConversions._
/**
* A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy]],
@@ -39,8 +40,7 @@ class RelationalGroupedDataset private[sql] (
groupType: proto.Aggregate.GroupType,
pivot: Option[proto.Aggregate.Pivot] = None,
groupingSets: Option[Seq[proto.Aggregate.GroupingSets]] = None)
- extends api.RelationalGroupedDataset[Dataset] {
- type RGD = RelationalGroupedDataset
+ extends api.RelationalGroupedDataset {
import df.sparkSession.RichColumn
protected def toDF(aggExprs: Seq[Column]): DataFrame = {
@@ -80,12 +80,7 @@ class RelationalGroupedDataset private[sql] (
colNames.map(df.col)
}
- /**
- * Returns a `KeyValueGroupedDataset` where the data is grouped by the grouping expressions of
- * current `RelationalGroupedDataset`.
- *
- * @since 3.5.0
- */
+ /** @inheritdoc */
def as[K: Encoder, T: Encoder]: KeyValueGroupedDataset[K, T] = {
KeyValueGroupedDatasetImpl[K, T](df, encoderFor[K], encoderFor[T], groupingExprs)
}
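A usage sketch of `as` (assumes `df` has exactly the columns `region: String` and `amount: Long`, plus `import spark.implicits._`):

    val byRegion = df.groupBy($"region").as[String, (String, Long)]
    val totals   = byRegion.mapGroups((region, rows) => region -> rows.map(_._2).sum)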
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 24d0a5ac7262f..04f8eeb5c6d46 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -17,11 +17,13 @@
package org.apache.spark.sql
import java.net.URI
+import java.nio.file.{Files, Paths}
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.{AtomicLong, AtomicReference}
import scala.jdk.CollectionConverters._
import scala.reflect.runtime.universe.TypeTag
+import scala.util.Try
import com.google.common.cache.{CacheBuilder, CacheLoader}
import io.grpc.ClientInterceptor
@@ -39,7 +41,7 @@ import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, Spar
import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
import org.apache.spark.sql.functions.lit
-import org.apache.spark.sql.internal.{CatalogImpl, SessionCleaner, SqlApiConf}
+import org.apache.spark.sql.internal.{CatalogImpl, ConnectRuntimeConfig, SessionCleaner, SqlApiConf}
import org.apache.spark.sql.internal.ColumnNodeToProtoConverter.{toExpr, toTypedExpr}
import org.apache.spark.sql.streaming.DataStreamReader
import org.apache.spark.sql.streaming.StreamingQueryManager
@@ -67,7 +69,7 @@ import org.apache.spark.util.ArrayImplicits._
class SparkSession private[sql] (
private[sql] val client: SparkConnectClient,
private val planIdGenerator: AtomicLong)
- extends api.SparkSession[Dataset]
+ extends api.SparkSession
with Logging {
private[this] val allocator = new RootAllocator()
@@ -86,16 +88,8 @@ class SparkSession private[sql] (
client.hijackServerSideSessionIdForTesting(suffix)
}
- /**
- * Runtime configuration interface for Spark.
- *
- * This is the interface through which the user can get and set all Spark configurations that
- * are relevant to Spark SQL. When getting the value of a config, this defaults to the value set
- * in server, if any.
- *
- * @since 3.4.0
- */
- val conf: RuntimeConfig = new RuntimeConfig(client)
+ /** @inheritdoc */
+ val conf: RuntimeConfig = new ConnectRuntimeConfig(client)
/** @inheritdoc */
@transient
@@ -212,16 +206,7 @@ class SparkSession private[sql] (
sql(query, Array.empty)
}
- /**
- * Returns a [[DataFrameReader]] that can be used to read non-streaming data in as a
- * `DataFrame`.
- * {{{
- * sparkSession.read.parquet("/path/to/file.parquet")
- * sparkSession.read.schema(schema).json("/path/to/file.json")
- * }}}
- *
- * @since 3.4.0
- */
+ /** @inheritdoc */
def read: DataFrameReader = new DataFrameReader(this)
/**
@@ -237,12 +222,7 @@ class SparkSession private[sql] (
lazy val streams: StreamingQueryManager = new StreamingQueryManager(this)
- /**
- * Interface through which the user may create, drop, alter or query underlying databases,
- * tables, functions etc.
- *
- * @since 3.5.0
- */
+ /** @inheritdoc */
lazy val catalog: Catalog = new CatalogImpl(this)
/** @inheritdoc */
@@ -440,7 +420,7 @@ class SparkSession private[sql] (
*
* @since 3.5.0
*/
- def interruptAll(): Seq[String] = {
+ override def interruptAll(): Seq[String] = {
client.interruptAll().getInterruptedIdsList.asScala.toSeq
}
@@ -453,7 +433,7 @@ class SparkSession private[sql] (
*
* @since 3.5.0
*/
- def interruptTag(tag: String): Seq[String] = {
+ override def interruptTag(tag: String): Seq[String] = {
client.interruptTag(tag).getInterruptedIdsList.asScala.toSeq
}
@@ -466,7 +446,7 @@ class SparkSession private[sql] (
*
* @since 3.5.0
*/
- def interruptOperation(operationId: String): Seq[String] = {
+ override def interruptOperation(operationId: String): Seq[String] = {
client.interruptOperation(operationId).getInterruptedIdsList.asScala.toSeq
}
@@ -497,65 +477,17 @@ class SparkSession private[sql] (
SparkSession.onSessionClose(this)
}
- /**
- * Add a tag to be assigned to all the operations started by this thread in this session.
- *
- * Often, a unit of execution in an application consists of multiple Spark executions.
- * Application programmers can use this method to group all those jobs together and give a group
- * tag. The application can use `org.apache.spark.sql.SparkSession.interruptTag` to cancel all
- * running executions with this tag. For example:
- * {{{
- * // In the main thread:
- * spark.addTag("myjobs")
- * spark.range(10).map(i => { Thread.sleep(10); i }).collect()
- *
- * // In a separate thread:
- * spark.interruptTag("myjobs")
- * }}}
- *
- * There may be multiple tags present at the same time, so different parts of application may
- * use different tags to perform cancellation at different levels of granularity.
- *
- * @param tag
- * The tag to be added. Cannot contain ',' (comma) character or be an empty string.
- *
- * @since 3.5.0
- */
- def addTag(tag: String): Unit = {
- client.addTag(tag)
- }
+ /** @inheritdoc */
+ override def addTag(tag: String): Unit = client.addTag(tag)
- /**
- * Remove a tag previously added to be assigned to all the operations started by this thread in
- * this session. Noop if such a tag was not added earlier.
- *
- * @param tag
- * The tag to be removed. Cannot contain ',' (comma) character or be an empty string.
- *
- * @since 3.5.0
- */
- def removeTag(tag: String): Unit = {
- client.removeTag(tag)
- }
+ /** @inheritdoc */
+ override def removeTag(tag: String): Unit = client.removeTag(tag)
- /**
- * Get the tags that are currently set to be assigned to all the operations started by this
- * thread.
- *
- * @since 3.5.0
- */
- def getTags(): Set[String] = {
- client.getTags()
- }
+ /** @inheritdoc */
+ override def getTags(): Set[String] = client.getTags()
- /**
- * Clear the current thread's operation tags.
- *
- * @since 3.5.0
- */
- def clearTags(): Unit = {
- client.clearTags()
- }
+ /** @inheritdoc */
+ override def clearTags(): Unit = client.clearTags()
/**
* We cannot deserialize a connect [[SparkSession]] because of a class clash on the server side.
@@ -591,6 +523,10 @@ class SparkSession private[sql] (
object SparkSession extends Logging {
private val MAX_CACHED_SESSIONS = 100
private val planIdGenerator = new AtomicLong
+ private var server: Option[Process] = None
+ private[sql] val sparkOptions = sys.props.filter { p =>
+ p._1.startsWith("spark.") && p._2.nonEmpty
+ }.toMap
private val sessions = CacheBuilder
.newBuilder()
@@ -623,6 +559,51 @@ object SparkSession extends Logging {
}
}
+ /**
+ * Start a local Spark Connect server (if needed) and evaluate `f` against it.
+ */
+ private[sql] def withLocalConnectServer[T](f: => T): T = {
+ synchronized {
+ val remoteString = sparkOptions
+ .get("spark.remote")
+ .orElse(Option(System.getProperty("spark.remote"))) // Set from Spark Submit
+ .orElse(sys.env.get(SparkConnectClient.SPARK_REMOTE))
+
+ val maybeConnectScript =
+ Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))
+
+ if (server.isEmpty &&
+ remoteString.exists(_.startsWith("local")) &&
+ maybeConnectScript.exists(Files.exists(_))) {
+ server = Some {
+ val args =
+ Seq(maybeConnectScript.get.toString, "--master", remoteString.get) ++ sparkOptions
+ .filter(p => !p._1.startsWith("spark.remote"))
+ .flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
+ val pb = new ProcessBuilder(args: _*)
+ // Remove SPARK_REMOTE so the spark-sql jar is not excluded from the classpath.
+ pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
+ pb.start()
+ }
+
+ // Let the server start. We will request to set the configurations right away,
+ // and this sleep reduces the noise from retries.
+ Thread.sleep(2000L)
+ System.setProperty("spark.remote", "sc://localhost")
+
+ // scalastyle:off runtimeaddshutdownhook
+ Runtime.getRuntime.addShutdownHook(new Thread() {
+ override def run(): Unit = if (server.isDefined) {
+ new ProcessBuilder(maybeConnectScript.get.toString)
+ .start()
+ }
+ })
+ // scalastyle:on runtimeaddshutdownhook
+ }
+ }
+ f
+ }
+
/**
* Create a new [[SparkSession]] based on the connect client [[Configuration]].
*/
@@ -765,6 +746,16 @@ object SparkSession extends Logging {
}
private def applyOptions(session: SparkSession): Unit = {
+ // Only attempt to set Spark SQL configurations.
+ // Setting a static configuration might throw an exception, so
+ // simply ignore the failure for now.
+ sparkOptions
+ .filter { case (k, _) =>
+ k.startsWith("spark.sql.")
+ }
+ .foreach { case (key, value) =>
+ Try(session.conf.set(key, value))
+ }
options.foreach { case (key, value) =>
session.conf.set(key, value)
}
@@ -787,7 +778,7 @@ object SparkSession extends Logging {
*
* @since 3.5.0
*/
- def create(): SparkSession = {
+ def create(): SparkSession = withLocalConnectServer {
val session = tryCreateSessionFromClient()
.getOrElse(SparkSession.this.create(builder.configuration))
setDefaultAndActiveSession(session)
@@ -807,7 +798,7 @@ object SparkSession extends Logging {
*
* @since 3.5.0
*/
- def getOrCreate(): SparkSession = {
+ def getOrCreate(): SparkSession = withLocalConnectServer {
val session = tryCreateSessionFromClient()
.getOrElse({
var existingSession = sessions.get(builder.configuration)
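
To illustrate the effect of `withLocalConnectServer` wrapping `create()`/`getOrCreate()` above: when `spark.remote` points at a local master and `SPARK_HOME` contains `sbin/start-connect-server.sh`, a plain builder call starts the server once and then connects to it. A hedged sketch; the property value and program name are assumptions, not part of this patch:

```scala
import org.apache.spark.sql.SparkSession

object LocalConnectSketch {
  def main(args: Array[String]): Unit = {
    // Assumes SPARK_HOME is set and sbin/start-connect-server.sh exists.
    System.setProperty("spark.remote", "local[2]")

    // getOrCreate() is wrapped in withLocalConnectServer, so the server is
    // launched once and spark.remote is rewritten to sc://localhost.
    val spark = SparkSession.builder().getOrCreate()
    println(spark.range(5).count()) // executes against the local Connect server
    spark.close()
  }
}
```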
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala
index 86775803a0937..63fa2821a6c6a 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala
@@ -17,10 +17,8 @@
package org.apache.spark.sql.application
import java.io.{InputStream, OutputStream}
-import java.nio.file.Paths
import java.util.concurrent.Semaphore
-import scala.util.Try
import scala.util.control.NonFatal
import ammonite.compiler.CodeClassWrapper
@@ -34,6 +32,7 @@ import ammonite.util.Util.newLine
import org.apache.spark.SparkBuildInfo.spark_version
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.SparkSession.withLocalConnectServer
import org.apache.spark.sql.connect.client.{SparkConnectClient, SparkConnectClientParser}
/**
@@ -64,37 +63,7 @@ Spark session available as 'spark'.
semaphore: Option[Semaphore] = None,
inputStream: InputStream = System.in,
outputStream: OutputStream = System.out,
- errorStream: OutputStream = System.err): Unit = {
- val configs: Map[String, String] =
- sys.props
- .filter(p =>
- p._1.startsWith("spark.") &&
- p._2.nonEmpty &&
- // Don't include spark.remote that we manually set later.
- !p._1.startsWith("spark.remote"))
- .toMap
-
- val remoteString: Option[String] =
- Option(System.getProperty("spark.remote")) // Set from Spark Submit
- .orElse(sys.env.get(SparkConnectClient.SPARK_REMOTE))
-
- if (remoteString.exists(_.startsWith("local"))) {
- server = Some {
- val args = Seq(
- Paths.get(sparkHome, "sbin", "start-connect-server.sh").toString,
- "--master",
- remoteString.get) ++ configs.flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
- val pb = new ProcessBuilder(args: _*)
- // So don't exclude spark-sql jar in classpath
- pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
- pb.start()
- }
- // Let the server start. We will directly request to set the configurations
- // and this sleep makes less noisy with retries.
- Thread.sleep(2000L)
- System.setProperty("spark.remote", "sc://localhost")
- }
-
+ errorStream: OutputStream = System.err): Unit = withLocalConnectServer {
// Build the client.
val client =
try {
@@ -118,13 +87,6 @@ Spark session available as 'spark'.
// Build the session.
val spark = SparkSession.builder().client(client).getOrCreate()
-
- // The configurations might not be all runtime configurations.
- // Try to set them with ignoring failures for now.
- configs
- .filter(_._1.startsWith("spark.sql"))
- .foreach { case (k, v) => Try(spark.conf.set(k, v)) }
-
val sparkBind = new Bind("spark", spark)
// Add the proper imports and register a [[ClassFinder]].
@@ -197,18 +159,12 @@ Spark session available as 'spark'.
}
}
}
- try {
- if (semaphore.nonEmpty) {
- // Used for testing.
- main.run(sparkBind, new Bind[Semaphore]("semaphore", semaphore.get))
- } else {
- main.run(sparkBind)
- }
- } finally {
- if (server.isDefined) {
- new ProcessBuilder(Paths.get(sparkHome, "sbin", "stop-connect-server.sh").toString)
- .start()
- }
+
+ if (semaphore.nonEmpty) {
+ // Used for testing.
+ main.run(sparkBind, new Bind[Semaphore]("semaphore", semaphore.get))
+ } else {
+ main.run(sparkBind)
}
}
}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
index cf0fef147ee84..86b1dbe4754e6 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
@@ -17,660 +17,152 @@
package org.apache.spark.sql.catalog
-import scala.jdk.CollectionConverters._
+import java.util
-import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset}
+import org.apache.spark.sql.{api, DataFrame, Dataset}
+import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.types.StructType
-import org.apache.spark.storage.StorageLevel
-/**
- * Catalog interface for Spark. To access this, use `SparkSession.catalog`.
- *
- * @since 3.5.0
- */
-abstract class Catalog {
-
- /**
- * Returns the current database (namespace) in this session.
- *
- * @since 3.5.0
- */
- def currentDatabase: String
-
- /**
- * Sets the current database (namespace) in this session.
- *
- * @since 3.5.0
- */
- def setCurrentDatabase(dbName: String): Unit
-
- /**
- * Returns a list of databases (namespaces) available within the current catalog.
- *
- * @since 3.5.0
- */
- def listDatabases(): Dataset[Database]
-
- /**
- * Returns a list of databases (namespaces) which name match the specify pattern and available
- * within the current catalog.
- *
- * @since 3.5.0
- */
- def listDatabases(pattern: String): Dataset[Database]
-
- /**
- * Returns a list of tables/views in the current database (namespace). This includes all
- * temporary views.
- *
- * @since 3.5.0
- */
- def listTables(): Dataset[Table]
-
- /**
- * Returns a list of tables/views in the specified database (namespace) (the name can be
- * qualified with catalog). This includes all temporary views.
- *
- * @since 3.5.0
- */
- @throws[AnalysisException]("database does not exist")
- def listTables(dbName: String): Dataset[Table]
-
- /**
- * Returns a list of tables/views in the specified database (namespace) which name match the
- * specify pattern (the name can be qualified with catalog). This includes all temporary views.
- *
- * @since 3.5.0
- */
- @throws[AnalysisException]("database does not exist")
- def listTables(dbName: String, pattern: String): Dataset[Table]
-
- /**
- * Returns a list of functions registered in the current database (namespace). This includes all
- * temporary functions.
- *
- * @since 3.5.0
- */
- def listFunctions(): Dataset[Function]
-
- /**
- * Returns a list of functions registered in the specified database (namespace) (the name can be
- * qualified with catalog). This includes all built-in and temporary functions.
- *
- * @since 3.5.0
- */
- @throws[AnalysisException]("database does not exist")
- def listFunctions(dbName: String): Dataset[Function]
-
- /**
- * Returns a list of functions registered in the specified database (namespace) which name match
- * the specify pattern (the name can be qualified with catalog). This includes all built-in and
- * temporary functions.
- *
- * @since 3.5.0
- */
- @throws[AnalysisException]("database does not exist")
- def listFunctions(dbName: String, pattern: String): Dataset[Function]
-
- /**
- * Returns a list of columns for the given table/view or temporary view.
- *
- * @param tableName
- * is either a qualified or unqualified name that designates a table/view. It follows the same
- * resolution rule with SQL: search for temp views first then table/views in the current
- * database (namespace).
- * @since 3.5.0
- */
- @throws[AnalysisException]("table does not exist")
- def listColumns(tableName: String): Dataset[Column]
-
- /**
- * Returns a list of columns for the given table/view in the specified database under the Hive
- * Metastore.
- *
- * To list columns for table/view in other catalogs, please use `listColumns(tableName)` with
- * qualified table/view name instead.
- *
- * @param dbName
- * is an unqualified name that designates a database.
- * @param tableName
- * is an unqualified name that designates a table/view.
- * @since 3.5.0
- */
- @throws[AnalysisException]("database or table does not exist")
- def listColumns(dbName: String, tableName: String): Dataset[Column]
-
- /**
- * Get the database (namespace) with the specified name (can be qualified with catalog). This
- * throws an AnalysisException when the database (namespace) cannot be found.
- *
- * @since 3.5.0
- */
- @throws[AnalysisException]("database does not exist")
- def getDatabase(dbName: String): Database
-
- /**
- * Get the table or view with the specified name. This table can be a temporary view or a
- * table/view. This throws an AnalysisException when no Table can be found.
- *
- * @param tableName
- * is either a qualified or unqualified name that designates a table/view. It follows the same
- * resolution rule with SQL: search for temp views first then table/views in the current
- * database (namespace).
- * @since 3.5.0
- */
- @throws[AnalysisException]("table does not exist")
- def getTable(tableName: String): Table
-
- /**
- * Get the table or view with the specified name in the specified database under the Hive
- * Metastore. This throws an AnalysisException when no Table can be found.
- *
- * To get table/view in other catalogs, please use `getTable(tableName)` with qualified
- * table/view name instead.
- *
- * @since 3.5.0
- */
- @throws[AnalysisException]("database or table does not exist")
- def getTable(dbName: String, tableName: String): Table
-
- /**
- * Get the function with the specified name. This function can be a temporary function or a
- * function. This throws an AnalysisException when the function cannot be found.
- *
- * @param functionName
- * is either a qualified or unqualified name that designates a function. It follows the same
- * resolution rule with SQL: search for built-in/temp functions first then functions in the
- * current database (namespace).
- * @since 3.5.0
- */
- @throws[AnalysisException]("function does not exist")
- def getFunction(functionName: String): Function
-
- /**
- * Get the function with the specified name in the specified database under the Hive Metastore.
- * This throws an AnalysisException when the function cannot be found.
- *
- * To get functions in other catalogs, please use `getFunction(functionName)` with qualified
- * function name instead.
- *
- * @param dbName
- * is an unqualified name that designates a database.
- * @param functionName
- * is an unqualified name that designates a function in the specified database
- * @since 3.5.0
- */
- @throws[AnalysisException]("database or function does not exist")
- def getFunction(dbName: String, functionName: String): Function
-
- /**
- * Check if the database (namespace) with the specified name exists (the name can be qualified
- * with catalog).
- *
- * @since 3.5.0
- */
- def databaseExists(dbName: String): Boolean
-
- /**
- * Check if the table or view with the specified name exists. This can either be a temporary
- * view or a table/view.
- *
- * @param tableName
- * is either a qualified or unqualified name that designates a table/view. It follows the same
- * resolution rule with SQL: search for temp views first then table/views in the current
- * database (namespace).
- * @since 3.5.0
- */
- def tableExists(tableName: String): Boolean
-
- /**
- * Check if the table or view with the specified name exists in the specified database under the
- * Hive Metastore.
- *
- * To check existence of table/view in other catalogs, please use `tableExists(tableName)` with
- * qualified table/view name instead.
- *
- * @param dbName
- * is an unqualified name that designates a database.
- * @param tableName
- * is an unqualified name that designates a table.
- * @since 3.5.0
- */
- def tableExists(dbName: String, tableName: String): Boolean
-
- /**
- * Check if the function with the specified name exists. This can either be a temporary function
- * or a function.
- *
- * @param functionName
- * is either a qualified or unqualified name that designates a function. It follows the same
- * resolution rule with SQL: search for built-in/temp functions first then functions in the
- * current database (namespace).
- * @since 3.5.0
- */
- def functionExists(functionName: String): Boolean
-
- /**
- * Check if the function with the specified name exists in the specified database under the Hive
- * Metastore.
- *
- * To check existence of functions in other catalogs, please use `functionExists(functionName)`
- * with qualified function name instead.
- *
- * @param dbName
- * is an unqualified name that designates a database.
- * @param functionName
- * is an unqualified name that designates a function.
- * @since 3.5.0
- */
- def functionExists(dbName: String, functionName: String): Boolean
-
- /**
- * Creates a table from the given path and returns the corresponding DataFrame. It will use the
- * default data source configured by spark.sql.sources.default.
- *
- * @param tableName
- * is either a qualified or unqualified name that designates a table. If no database
- * identifier is provided, it refers to a table in the current database.
- * @since 3.5.0
- */
- @deprecated("use createTable instead.", "2.2.0")
- def createExternalTable(tableName: String, path: String): DataFrame = {
- createTable(tableName, path)
- }
-
- /**
- * Creates a table from the given path and returns the corresponding DataFrame. It will use the
- * default data source configured by spark.sql.sources.default.
- *
- * @param tableName
- * is either a qualified or unqualified name that designates a table. If no database
- * identifier is provided, it refers to a table in the current database.
- * @since 3.5.0
- */
- def createTable(tableName: String, path: String): DataFrame
-
- /**
- * Creates a table from the given path based on a data source and returns the corresponding
- * DataFrame.
- *
- * @param tableName
- * is either a qualified or unqualified name that designates a table. If no database
- * identifier is provided, it refers to a table in the current database.
- * @since 3.5.0
- */
- @deprecated("use createTable instead.", "2.2.0")
- def createExternalTable(tableName: String, path: String, source: String): DataFrame = {
- createTable(tableName, path, source)
- }
-
- /**
- * Creates a table from the given path based on a data source and returns the corresponding
- * DataFrame.
- *
- * @param tableName
- * is either a qualified or unqualified name that designates a table. If no database
- * identifier is provided, it refers to a table in the current database.
- * @since 3.5.0
- */
- def createTable(tableName: String, path: String, source: String): DataFrame
-
- /**
- * Creates a table from the given path based on a data source and a set of options. Then,
- * returns the corresponding DataFrame.
- *
- * @param tableName
- * is either a qualified or unqualified name that designates a table. If no database
- * identifier is provided, it refers to a table in the current database.
- * @since 3.5.0
- */
- @deprecated("use createTable instead.", "2.2.0")
- def createExternalTable(
+/** @inheritdoc */
+abstract class Catalog extends api.Catalog {
+
+ /** @inheritdoc */
+ override def listDatabases(): Dataset[Database]
+
+ /** @inheritdoc */
+ override def listDatabases(pattern: String): Dataset[Database]
+
+ /** @inheritdoc */
+ override def listTables(): Dataset[Table]
+
+ /** @inheritdoc */
+ override def listTables(dbName: String): Dataset[Table]
+
+ /** @inheritdoc */
+ override def listTables(dbName: String, pattern: String): Dataset[Table]
+
+ /** @inheritdoc */
+ override def listFunctions(): Dataset[Function]
+
+ /** @inheritdoc */
+ override def listFunctions(dbName: String): Dataset[Function]
+
+ /** @inheritdoc */
+ override def listFunctions(dbName: String, pattern: String): Dataset[Function]
+
+ /** @inheritdoc */
+ override def listColumns(tableName: String): Dataset[Column]
+
+ /** @inheritdoc */
+ override def listColumns(dbName: String, tableName: String): Dataset[Column]
+
+ /** @inheritdoc */
+ override def createTable(tableName: String, path: String): DataFrame
+
+ /** @inheritdoc */
+ override def createTable(tableName: String, path: String, source: String): DataFrame
+
+ /** @inheritdoc */
+ override def createTable(
tableName: String,
source: String,
- options: java.util.Map[String, String]): DataFrame = {
- createTable(tableName, source, options)
- }
-
- /**
- * Creates a table based on the dataset in a data source and a set of options. Then, returns the
- * corresponding DataFrame.
- *
- * @param tableName
- * is either a qualified or unqualified name that designates a table. If no database
- * identifier is provided, it refers to a table in the current database.
- * @since 3.5.0
- */
- def createTable(
+ options: Map[String, String]): DataFrame
+
+ /** @inheritdoc */
+ override def createTable(
tableName: String,
source: String,
- options: java.util.Map[String, String]): DataFrame = {
- createTable(tableName, source, options.asScala.toMap)
- }
-
- /**
- * (Scala-specific) Creates a table from the given path based on a data source and a set of
- * options. Then, returns the corresponding DataFrame.
- *
- * @param tableName
- * is either a qualified or unqualified name that designates a table. If no database
- * identifier is provided, it refers to a table in the current database.
- * @since 3.5.0
- */
- @deprecated("use createTable instead.", "2.2.0")
- def createExternalTable(
+ description: String,
+ options: Map[String, String]): DataFrame
+
+ /** @inheritdoc */
+ override def createTable(
tableName: String,
source: String,
- options: Map[String, String]): DataFrame = {
- createTable(tableName, source, options)
- }
-
- /**
- * (Scala-specific) Creates a table based on the dataset in a data source and a set of options.
- * Then, returns the corresponding DataFrame.
- *
- * @param tableName
- * is either a qualified or unqualified name that designates a table. If no database
- * identifier is provided, it refers to a table in the current database.
- * @since 3.5.0
- */
- def createTable(tableName: String, source: String, options: Map[String, String]): DataFrame
-
- /**
- * Create a table from the given path based on a data source, a schema and a set of options.
- * Then, returns the corresponding DataFrame.
- *
- * @param tableName
- * is either a qualified or unqualified name that designates a table. If no database
- * identifier is provided, it refers to a table in the current database.
- * @since 3.5.0
- */
- @deprecated("use createTable instead.", "2.2.0")
- def createExternalTable(
+ schema: StructType,
+ options: Map[String, String]): DataFrame
+
+ /** @inheritdoc */
+ override def createTable(
tableName: String,
source: String,
schema: StructType,
- options: java.util.Map[String, String]): DataFrame = {
- createTable(tableName, source, schema, options)
- }
-
- /**
- * Creates a table based on the dataset in a data source and a set of options. Then, returns the
- * corresponding DataFrame.
- *
- * @param tableName
- * is either a qualified or unqualified name that designates a table. If no database
- * identifier is provided, it refers to a table in the current database.
- * @since 3.5.0
- */
- def createTable(
+ description: String,
+ options: Map[String, String]): DataFrame
+
+ /** @inheritdoc */
+ override def listCatalogs(): Dataset[CatalogMetadata]
+
+ /** @inheritdoc */
+ override def listCatalogs(pattern: String): Dataset[CatalogMetadata]
+
+ /** @inheritdoc */
+ override def createExternalTable(tableName: String, path: String): DataFrame =
+ super.createExternalTable(tableName, path)
+
+ /** @inheritdoc */
+ override def createExternalTable(tableName: String, path: String, source: String): DataFrame =
+ super.createExternalTable(tableName, path, source)
+
+ /** @inheritdoc */
+ override def createExternalTable(
tableName: String,
source: String,
- description: String,
- options: java.util.Map[String, String]): DataFrame = {
- createTable(
- tableName,
- source = source,
- description = description,
- options = options.asScala.toMap)
- }
-
- /**
- * (Scala-specific) Creates a table based on the dataset in a data source and a set of options.
- * Then, returns the corresponding DataFrame.
- *
- * @param tableName
- * is either a qualified or unqualified name that designates a table. If no database
- * identifier is provided, it refers to a table in the current database.
- * @since 3.5.0
- */
- def createTable(
+ options: util.Map[String, String]): DataFrame =
+ super.createExternalTable(tableName, source, options)
+
+ /** @inheritdoc */
+ override def createTable(
tableName: String,
source: String,
- description: String,
- options: Map[String, String]): DataFrame
+ options: util.Map[String, String]): DataFrame =
+ super.createTable(tableName, source, options)
- /**
- * Create a table based on the dataset in a data source, a schema and a set of options. Then,
- * returns the corresponding DataFrame.
- *
- * @param tableName
- * is either a qualified or unqualified name that designates a table. If no database
- * identifier is provided, it refers to a table in the current database.
- * @since 3.5.0
- */
- def createTable(
+ /** @inheritdoc */
+ override def createExternalTable(
tableName: String,
source: String,
- schema: StructType,
- options: java.util.Map[String, String]): DataFrame = {
- createTable(tableName, source, schema, options.asScala.toMap)
- }
-
- /**
- * (Scala-specific) Create a table from the given path based on a data source, a schema and a
- * set of options. Then, returns the corresponding DataFrame.
- *
- * @param tableName
- * is either a qualified or unqualified name that designates a table. If no database
- * identifier is provided, it refers to a table in the current database.
- * @since 3.5.0
- */
- @deprecated("use createTable instead.", "2.2.0")
- def createExternalTable(
+ options: Map[String, String]): DataFrame =
+ super.createExternalTable(tableName, source, options)
+
+ /** @inheritdoc */
+ override def createExternalTable(
tableName: String,
source: String,
schema: StructType,
- options: Map[String, String]): DataFrame = {
- createTable(tableName, source, schema, options)
- }
-
- /**
- * (Scala-specific) Create a table based on the dataset in a data source, a schema and a set of
- * options. Then, returns the corresponding DataFrame.
- *
- * @param tableName
- * is either a qualified or unqualified name that designates a table. If no database
- * identifier is provided, it refers to a table in the current database.
- * @since 3.5.0
- */
- def createTable(
+ options: util.Map[String, String]): DataFrame =
+ super.createExternalTable(tableName, source, schema, options)
+
+ /** @inheritdoc */
+ override def createTable(
+ tableName: String,
+ source: String,
+ description: String,
+ options: util.Map[String, String]): DataFrame =
+ super.createTable(tableName, source, description, options)
+
+ /** @inheritdoc */
+ override def createTable(
tableName: String,
source: String,
schema: StructType,
- options: Map[String, String]): DataFrame
+ options: util.Map[String, String]): DataFrame =
+ super.createTable(tableName, source, schema, options)
- /**
- * Create a table based on the dataset in a data source, a schema and a set of options. Then,
- * returns the corresponding DataFrame.
- *
- * @param tableName
- * is either a qualified or unqualified name that designates a table. If no database
- * identifier is provided, it refers to a table in the current database.
- * @since 3.5.0
- */
- def createTable(
+ /** @inheritdoc */
+ override def createExternalTable(
tableName: String,
source: String,
schema: StructType,
- description: String,
- options: java.util.Map[String, String]): DataFrame = {
- createTable(
- tableName,
- source = source,
- schema = schema,
- description = description,
- options = options.asScala.toMap)
- }
-
- /**
- * (Scala-specific) Create a table based on the dataset in a data source, a schema and a set of
- * options. Then, returns the corresponding DataFrame.
- *
- * @param tableName
- * is either a qualified or unqualified name that designates a table. If no database
- * identifier is provided, it refers to a table in the current database.
- * @since 3.5.0
- */
- def createTable(
+ options: Map[String, String]): DataFrame =
+ super.createExternalTable(tableName, source, schema, options)
+
+ /** @inheritdoc */
+ override def createTable(
tableName: String,
source: String,
schema: StructType,
description: String,
- options: Map[String, String]): DataFrame
-
- /**
- * Drops the local temporary view with the given view name in the catalog. If the view has been
- * cached before, then it will also be uncached.
- *
- * Local temporary view is session-scoped. Its lifetime is the lifetime of the session that
- * created it, i.e. it will be automatically dropped when the session terminates. It's not tied
- * to any databases, i.e. we can't use `db1.view1` to reference a local temporary view.
- *
- * Note that, the return type of this method was Unit in Spark 2.0, but changed to Boolean in
- * Spark 2.1.
- *
- * @param viewName
- * the name of the temporary view to be dropped.
- * @return
- * true if the view is dropped successfully, false otherwise.
- * @since 3.5.0
- */
- def dropTempView(viewName: String): Boolean
-
- /**
- * Drops the global temporary view with the given view name in the catalog. If the view has been
- * cached before, then it will also be uncached.
- *
- * Global temporary view is cross-session. Its lifetime is the lifetime of the Spark
- * application, i.e. it will be automatically dropped when the application terminates. It's tied
- * to a system preserved database `global_temp`, and we must use the qualified name to refer a
- * global temp view, e.g. `SELECT * FROM global_temp.view1`.
- *
- * @param viewName
- * the unqualified name of the temporary view to be dropped.
- * @return
- * true if the view is dropped successfully, false otherwise.
- * @since 3.5.0
- */
- def dropGlobalTempView(viewName: String): Boolean
-
- /**
- * Recovers all the partitions in the directory of a table and update the catalog. Only works
- * with a partitioned table, and not a view.
- *
- * @param tableName
- * is either a qualified or unqualified name that designates a table. If no database
- * identifier is provided, it refers to a table in the current database.
- * @since 3.5.0
- */
- def recoverPartitions(tableName: String): Unit
-
- /**
- * Returns true if the table is currently cached in-memory.
- *
- * @param tableName
- * is either a qualified or unqualified name that designates a table/view. If no database
- * identifier is provided, it refers to a temporary view or a table/view in the current
- * database.
- * @since 3.5.0
- */
- def isCached(tableName: String): Boolean
-
- /**
- * Caches the specified table in-memory.
- *
- * @param tableName
- * is either a qualified or unqualified name that designates a table/view. If no database
- * identifier is provided, it refers to a temporary view or a table/view in the current
- * database.
- * @since 3.5.0
- */
- def cacheTable(tableName: String): Unit
-
- /**
- * Caches the specified table with the given storage level.
- *
- * @param tableName
- * is either a qualified or unqualified name that designates a table/view. If no database
- * identifier is provided, it refers to a temporary view or a table/view in the current
- * database.
- * @param storageLevel
- * storage level to cache table.
- * @since 3.5.0
- */
- def cacheTable(tableName: String, storageLevel: StorageLevel): Unit
-
- /**
- * Removes the specified table from the in-memory cache.
- *
- * @param tableName
- * is either a qualified or unqualified name that designates a table/view. If no database
- * identifier is provided, it refers to a temporary view or a table/view in the current
- * database.
- * @since 3.5.0
- */
- def uncacheTable(tableName: String): Unit
-
- /**
- * Removes all cached tables from the in-memory cache.
- *
- * @since 3.5.0
- */
- def clearCache(): Unit
-
- /**
- * Invalidates and refreshes all the cached data and metadata of the given table. For
- * performance reasons, Spark SQL or the external data source library it uses might cache
- * certain metadata about a table, such as the location of blocks. When those change outside of
- * Spark SQL, users should call this function to invalidate the cache.
- *
- * If this table is cached as an InMemoryRelation, drop the original cached version and make the
- * new version cached lazily.
- *
- * @param tableName
- * is either a qualified or unqualified name that designates a table/view. If no database
- * identifier is provided, it refers to a temporary view or a table/view in the current
- * database.
- * @since 3.5.0
- */
- def refreshTable(tableName: String): Unit
-
- /**
- * Invalidates and refreshes all the cached data (and the associated metadata) for any `Dataset`
- * that contains the given data source path. Path matching is by prefix, i.e. "/" would
- * invalidate everything that is cached.
- *
- * @since 3.5.0
- */
- def refreshByPath(path: String): Unit
-
- /**
- * Returns the current catalog in this session.
- *
- * @since 3.5.0
- */
- def currentCatalog(): String
-
- /**
- * Sets the current catalog in this session.
- *
- * @since 3.5.0
- */
- def setCurrentCatalog(catalogName: String): Unit
-
- /**
- * Returns a list of catalogs available in this session.
- *
- * @since 3.5.0
- */
- def listCatalogs(): Dataset[CatalogMetadata]
-
- /**
- * Returns a list of catalogs which name match the specify pattern and available in this
- * session.
- *
- * @since 3.5.0
- */
- def listCatalogs(pattern: String): Dataset[CatalogMetadata]
+ options: util.Map[String, String]): DataFrame =
+ super.createTable(tableName, source, schema, description, options)
}
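
The slimmed-down `Catalog` above only re-declares the `Dataset`-returning members and routes the Java-friendly overloads to the shared `api.Catalog` defaults; the user-facing surface is unchanged. A short usage sketch (database and table names are made up):

```scala
// Assumes an active Connect session `spark`.
spark.catalog.setCurrentDatabase("default")

val tableNames = spark.catalog.listTables().collect().map(_.name)

if (spark.catalog.tableExists("people")) {
  spark.catalog.cacheTable("people")
  spark.catalog.uncacheTable("people")
}
```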
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectConversions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectConversions.scala
new file mode 100644
index 0000000000000..7d81f4ead7857
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectConversions.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect
+
+import scala.language.implicitConversions
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql._
+
+/**
+ * Conversions from sql interfaces to the Connect specific implementation.
+ *
+ * This class is mainly used by the implementation; in the case of Connect it should be extremely
+ * rare that a developer needs these conversions directly.
+ *
+ * We provide both a trait and an object. The trait is useful when an extension developer needs
+ * these conversions in a project that covers multiple Spark versions: they can create a shim for
+ * the conversions whose Spark 4+ variant implements this trait, while the shims for older
+ * versions do not.
+ */
+@DeveloperApi
+trait ConnectConversions {
+ implicit def castToImpl(session: api.SparkSession): SparkSession =
+ session.asInstanceOf[SparkSession]
+
+ implicit def castToImpl[T](ds: api.Dataset[T]): Dataset[T] =
+ ds.asInstanceOf[Dataset[T]]
+
+ implicit def castToImpl(rgds: api.RelationalGroupedDataset): RelationalGroupedDataset =
+ rgds.asInstanceOf[RelationalGroupedDataset]
+
+ implicit def castToImpl[K, V](
+ kvds: api.KeyValueGroupedDataset[K, V]): KeyValueGroupedDataset[K, V] =
+ kvds.asInstanceOf[KeyValueGroupedDataset[K, V]]
+}
+
+object ConnectConversions extends ConnectConversions
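
A brief sketch of how these implicits are intended to be used by extension code written against the shared `api` types; the helper below is hypothetical:

```scala
import org.apache.spark.sql.{api, Dataset}
import org.apache.spark.sql.connect.ConnectConversions._

// A value typed against the shared interface can be returned where the Connect
// implementation type is required; castToImpl supplies the conversion.
def narrow[T](ds: api.Dataset[T]): Dataset[T] = ds
```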
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/ConnectRuntimeConfig.scala
similarity index 68%
rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala
rename to connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/ConnectRuntimeConfig.scala
index f77dd512ef257..7578e2424fb42 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/ConnectRuntimeConfig.scala
@@ -14,10 +14,11 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.spark.sql
+package org.apache.spark.sql.internal
import org.apache.spark.connect.proto.{ConfigRequest, ConfigResponse, KeyValue}
import org.apache.spark.internal.Logging
+import org.apache.spark.sql.RuntimeConfig
import org.apache.spark.sql.connect.client.SparkConnectClient
/**
@@ -25,61 +26,31 @@ import org.apache.spark.sql.connect.client.SparkConnectClient
*
* @since 3.4.0
*/
-class RuntimeConfig private[sql] (client: SparkConnectClient) extends Logging {
+class ConnectRuntimeConfig private[sql] (client: SparkConnectClient)
+ extends RuntimeConfig
+ with Logging {
- /**
- * Sets the given Spark runtime configuration property.
- *
- * @since 3.4.0
- */
+ /** @inheritdoc */
def set(key: String, value: String): Unit = {
executeConfigRequest { builder =>
builder.getSetBuilder.addPairsBuilder().setKey(key).setValue(value)
}
}
- /**
- * Sets the given Spark runtime configuration property.
- *
- * @since 3.4.0
- */
- def set(key: String, value: Boolean): Unit = set(key, String.valueOf(value))
-
- /**
- * Sets the given Spark runtime configuration property.
- *
- * @since 3.4.0
- */
- def set(key: String, value: Long): Unit = set(key, String.valueOf(value))
-
- /**
- * Returns the value of Spark runtime configuration property for the given key.
- *
- * @throws java.util.NoSuchElementException
- * if the key is not set and does not have a default value
- * @since 3.4.0
- */
+ /** @inheritdoc */
@throws[NoSuchElementException]("if the key is not set")
def get(key: String): String = getOption(key).getOrElse {
throw new NoSuchElementException(key)
}
- /**
- * Returns the value of Spark runtime configuration property for the given key.
- *
- * @since 3.4.0
- */
+ /** @inheritdoc */
def get(key: String, default: String): String = {
executeConfigRequestSingleValue { builder =>
builder.getGetWithDefaultBuilder.addPairsBuilder().setKey(key).setValue(default)
}
}
- /**
- * Returns all properties set in this conf.
- *
- * @since 3.4.0
- */
+ /** @inheritdoc */
def getAll: Map[String, String] = {
val response = executeConfigRequest { builder =>
builder.getGetAllBuilder
@@ -92,11 +63,7 @@ class RuntimeConfig private[sql] (client: SparkConnectClient) extends Logging {
builder.result()
}
- /**
- * Returns the value of Spark runtime configuration property for the given key.
- *
- * @since 3.4.0
- */
+ /** @inheritdoc */
def getOption(key: String): Option[String] = {
val pair = executeConfigRequestSinglePair { builder =>
builder.getGetOptionBuilder.addKeys(key)
@@ -108,27 +75,14 @@ class RuntimeConfig private[sql] (client: SparkConnectClient) extends Logging {
}
}
- /**
- * Resets the configuration property for the given key.
- *
- * @since 3.4.0
- */
+ /** @inheritdoc */
def unset(key: String): Unit = {
executeConfigRequest { builder =>
builder.getUnsetBuilder.addKeys(key)
}
}
- /**
- * Indicates whether the configuration property with the given key is modifiable in the current
- * session.
- *
- * @return
- * `true` if the configuration property is modifiable. For static SQL, Spark Core, invalid
- * (not existing) and other non-modifiable configuration properties, the returned value is
- * `false`.
- * @since 3.4.0
- */
+ /** @inheritdoc */
def isModifiable(key: String): Boolean = {
val modifiable = executeConfigRequestSingleValue { builder =>
builder.getIsModifiableBuilder.addKeys(key)
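
Behaviour-wise the rename is transparent: the object is still reached through `SparkSession.conf`. A small sketch of the calls implemented above (the configuration keys are examples only):

```scala
// Assumes an active Connect session `spark`.
spark.conf.set("spark.sql.shuffle.partitions", "8")
assert(spark.conf.get("spark.sql.shuffle.partitions") == "8")
assert(spark.conf.getOption("made.up.key").isEmpty)

if (spark.conf.isModifiable("spark.sql.session.timeZone")) {
  spark.conf.set("spark.sql.session.timeZone", "UTC")
}

spark.conf.unset("spark.sql.shuffle.partitions")
```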
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterV2Impl.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterV2Impl.scala
new file mode 100644
index 0000000000000..4afa8b6d566c5
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterV2Impl.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.internal
+
+import scala.jdk.CollectionConverters._
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.{Column, DataFrameWriterV2, Dataset}
+
+/**
+ * Interface used to write a [[org.apache.spark.sql.Dataset]] to external storage using the v2
+ * API.
+ *
+ * @since 3.4.0
+ */
+@Experimental
+final class DataFrameWriterV2Impl[T] private[sql] (table: String, ds: Dataset[T])
+ extends DataFrameWriterV2[T] {
+ import ds.sparkSession.RichColumn
+
+ private val builder = proto.WriteOperationV2
+ .newBuilder()
+ .setInput(ds.plan.getRoot)
+ .setTableName(table)
+
+ /** @inheritdoc */
+ override def using(provider: String): this.type = {
+ builder.setProvider(provider)
+ this
+ }
+
+ /** @inheritdoc */
+ override def option(key: String, value: String): this.type = {
+ builder.putOptions(key, value)
+ this
+ }
+
+ /** @inheritdoc */
+ override def options(options: scala.collection.Map[String, String]): this.type = {
+ builder.putAllOptions(options.asJava)
+ this
+ }
+
+ /** @inheritdoc */
+ override def options(options: java.util.Map[String, String]): this.type = {
+ builder.putAllOptions(options)
+ this
+ }
+
+ /** @inheritdoc */
+ override def tableProperty(property: String, value: String): this.type = {
+ builder.putTableProperties(property, value)
+ this
+ }
+
+ /** @inheritdoc */
+ @scala.annotation.varargs
+ override def partitionedBy(column: Column, columns: Column*): this.type = {
+ builder.addAllPartitioningColumns((column +: columns).map(_.expr).asJava)
+ this
+ }
+
+ /** @inheritdoc */
+ @scala.annotation.varargs
+ override def clusterBy(colName: String, colNames: String*): this.type = {
+ builder.addAllClusteringColumns((colName +: colNames).asJava)
+ this
+ }
+
+ /** @inheritdoc */
+ override def create(): Unit = {
+ executeWriteOperation(proto.WriteOperationV2.Mode.MODE_CREATE)
+ }
+
+ /** @inheritdoc */
+ override def replace(): Unit = {
+ executeWriteOperation(proto.WriteOperationV2.Mode.MODE_REPLACE)
+ }
+
+ /** @inheritdoc */
+ override def createOrReplace(): Unit = {
+ executeWriteOperation(proto.WriteOperationV2.Mode.MODE_CREATE_OR_REPLACE)
+ }
+
+ /** @inheritdoc */
+ def append(): Unit = {
+ executeWriteOperation(proto.WriteOperationV2.Mode.MODE_APPEND)
+ }
+
+ /** @inheritdoc */
+ def overwrite(condition: Column): Unit = {
+ builder.setOverwriteCondition(condition.expr)
+ executeWriteOperation(proto.WriteOperationV2.Mode.MODE_OVERWRITE)
+ }
+
+ /** @inheritdoc */
+ def overwritePartitions(): Unit = {
+ executeWriteOperation(proto.WriteOperationV2.Mode.MODE_OVERWRITE_PARTITIONS)
+ }
+
+ private def executeWriteOperation(mode: proto.WriteOperationV2.Mode): Unit = {
+ val command = proto.Command
+ .newBuilder()
+ .setWriteOperationV2(builder.setMode(mode))
+ .build()
+ ds.sparkSession.execute(command)
+ }
+}
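
`DataFrameWriterV2Impl` is the Connect-side backing for the builder returned by `Dataset.writeTo`. A hedged usage sketch; the catalog, namespace, and table name are assumptions:

```scala
import org.apache.spark.sql.functions.col

// Assumes an active Connect session `spark` and a writable catalog.
spark.range(100)
  .withColumn("bucket", col("id") % 4)
  .writeTo("my_catalog.db.numbers")
  .using("parquet")
  .tableProperty("owner", "etl")
  .partitionedBy(col("bucket"))
  .createOrReplace()
```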
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/MergeIntoWriterImpl.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/MergeIntoWriterImpl.scala
new file mode 100644
index 0000000000000..fba3c6343558b
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/MergeIntoWriterImpl.scala
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.internal
+
+import org.apache.spark.SparkRuntimeException
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto.{Expression, MergeAction, MergeIntoTableCommand}
+import org.apache.spark.connect.proto.MergeAction.ActionType._
+import org.apache.spark.sql.{Column, Dataset, MergeIntoWriter}
+import org.apache.spark.sql.functions.expr
+
+/**
+ * `MergeIntoWriter` provides methods to define and execute merge actions based on specified
+ * conditions.
+ *
+ * @tparam T
+ * the type of data in the Dataset.
+ * @param table
+ * the name of the target table for the merge operation.
+ * @param ds
+ * the source Dataset to merge into the target table.
+ * @param on
+ * the merge condition.
+ *
+ * @since 4.0.0
+ */
+@Experimental
+class MergeIntoWriterImpl[T] private[sql] (table: String, ds: Dataset[T], on: Column)
+ extends MergeIntoWriter[T] {
+ import ds.sparkSession.RichColumn
+
+ private val builder = MergeIntoTableCommand
+ .newBuilder()
+ .setTargetTableName(table)
+ .setSourceTablePlan(ds.plan.getRoot)
+ .setMergeCondition(on.expr)
+
+ /**
+ * Executes the merge operation.
+ */
+ def merge(): Unit = {
+ if (builder.getMatchActionsCount == 0 &&
+ builder.getNotMatchedActionsCount == 0 &&
+ builder.getNotMatchedBySourceActionsCount == 0) {
+ throw new SparkRuntimeException(
+ errorClass = "NO_MERGE_ACTION_SPECIFIED",
+ messageParameters = Map.empty)
+ }
+ ds.sparkSession.execute(
+ proto.Command
+ .newBuilder()
+ .setMergeIntoTableCommand(builder.setWithSchemaEvolution(schemaEvolutionEnabled))
+ .build())
+ }
+
+ override protected[sql] def insertAll(condition: Option[Column]): MergeIntoWriter[T] = {
+ builder.addNotMatchedActions(buildMergeAction(ACTION_TYPE_INSERT_STAR, condition))
+ this
+ }
+
+ override protected[sql] def insert(
+ condition: Option[Column],
+ map: Map[String, Column]): MergeIntoWriter[T] = {
+ builder.addNotMatchedActions(buildMergeAction(ACTION_TYPE_INSERT, condition, map))
+ this
+ }
+
+ override protected[sql] def updateAll(
+ condition: Option[Column],
+ notMatchedBySource: Boolean): MergeIntoWriter[T] = {
+ appendUpdateDeleteAction(
+ buildMergeAction(ACTION_TYPE_UPDATE_STAR, condition),
+ notMatchedBySource)
+ }
+
+ override protected[sql] def update(
+ condition: Option[Column],
+ map: Map[String, Column],
+ notMatchedBySource: Boolean): MergeIntoWriter[T] = {
+ appendUpdateDeleteAction(
+ buildMergeAction(ACTION_TYPE_UPDATE, condition, map),
+ notMatchedBySource)
+ }
+
+ override protected[sql] def delete(
+ condition: Option[Column],
+ notMatchedBySource: Boolean): MergeIntoWriter[T] = {
+ appendUpdateDeleteAction(buildMergeAction(ACTION_TYPE_DELETE, condition), notMatchedBySource)
+ }
+
+ private def appendUpdateDeleteAction(
+ action: Expression,
+ notMatchedBySource: Boolean): MergeIntoWriter[T] = {
+ if (notMatchedBySource) {
+ builder.addNotMatchedBySourceActions(action)
+ } else {
+ builder.addMatchActions(action)
+ }
+ this
+ }
+
+ private def buildMergeAction(
+ actionType: MergeAction.ActionType,
+ condition: Option[Column],
+ assignments: Map[String, Column] = Map.empty): Expression = {
+ val builder = proto.MergeAction.newBuilder().setActionType(actionType)
+ condition.foreach(c => builder.setCondition(c.expr))
+ assignments.foreach { case (k, v) =>
+ builder
+ .addAssignmentsBuilder()
+ .setKey(expr(k).expr)
+ .setValue(v.expr)
+ }
+ Expression
+ .newBuilder()
+ .setMergeAction(builder)
+ .build()
+ }
+}
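
`MergeIntoWriterImpl` backs the builder returned by `Dataset.mergeInto`; the sketch below shows the call chain it serves (table and column names are made up):

```scala
import org.apache.spark.sql.functions.col

// Assumes an active Connect session `spark` and an existing target table.
spark.table("updates")
  .mergeInto("target", col("target.id") === col("updates.id"))
  .whenMatched()
  .updateAll()
  .whenNotMatched()
  .insertAll()
  .merge() // would fail with NO_MERGE_ACTION_SPECIFIED if no action had been added
```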
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala
index 13a26fa79085e..29fbcc443deb9 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala
@@ -18,166 +18,21 @@
package org.apache.spark.sql.streaming
import java.util.UUID
-import java.util.concurrent.TimeoutException
import scala.jdk.CollectionConverters._
-import org.apache.spark.annotation.Evolving
import org.apache.spark.connect.proto.Command
import org.apache.spark.connect.proto.ExecutePlanResponse
import org.apache.spark.connect.proto.StreamingQueryCommand
import org.apache.spark.connect.proto.StreamingQueryCommandResult
import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.StreamingQueryInstance
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{api, SparkSession}
-/**
- * A handle to a query that is executing continuously in the background as new data arrives. All
- * these methods are thread-safe.
- * @since 3.5.0
- */
-@Evolving
-trait StreamingQuery {
- // This is a copy of StreamingQuery in sql/core/.../streaming/StreamingQuery.scala
-
- /**
- * Returns the user-specified name of the query, or null if not specified. This name can be
- * specified in the `org.apache.spark.sql.streaming.DataStreamWriter` as
- * `dataframe.writeStream.queryName("query").start()`. This name, if set, must be unique across
- * all active queries.
- *
- * @since 3.5.0
- */
- def name: String
-
- /**
- * Returns the unique id of this query that persists across restarts from checkpoint data. That
- * is, this id is generated when a query is started for the first time, and will be the same
- * every time it is restarted from checkpoint data. Also see [[runId]].
- *
- * @since 3.5.0
- */
- def id: UUID
-
- /**
- * Returns the unique id of this run of the query. That is, every start/restart of a query will
- * generate a unique runId. Therefore, every time a query is restarted from checkpoint, it will
- * have the same [[id]] but different [[runId]]s.
- */
- def runId: UUID
-
- /**
- * Returns the `SparkSession` associated with `this`.
- *
- * @since 3.5.0
- */
- def sparkSession: SparkSession
-
- /**
- * Returns `true` if this query is actively running.
- *
- * @since 3.5.0
- */
- def isActive: Boolean
-
- /**
- * Returns the [[StreamingQueryException]] if the query was terminated by an exception.
- * @since 3.5.0
- */
- def exception: Option[StreamingQueryException]
-
- /**
- * Returns the current status of the query.
- *
- * @since 3.5.0
- */
- def status: StreamingQueryStatus
-
- /**
- * Returns an array of the most recent [[StreamingQueryProgress]] updates for this query. The
- * number of progress updates retained for each stream is configured by Spark session
- * configuration `spark.sql.streaming.numRecentProgressUpdates`.
- *
- * @since 3.5.0
- */
- def recentProgress: Array[StreamingQueryProgress]
-
- /**
- * Returns the most recent [[StreamingQueryProgress]] update of this streaming query.
- *
- * @since 3.5.0
- */
- def lastProgress: StreamingQueryProgress
-
- /**
- * Waits for the termination of `this` query, either by `query.stop()` or by an exception. If
- * the query has terminated with an exception, then the exception will be thrown.
- *
- * If the query has terminated, then all subsequent calls to this method will either return
- * immediately (if the query was terminated by `stop()`), or throw the exception immediately (if
- * the query has terminated with exception).
- *
- * @throws StreamingQueryException
- * if the query has terminated with an exception.
- * @since 3.5.0
- */
- @throws[StreamingQueryException]
- def awaitTermination(): Unit
-
- /**
- * Waits for the termination of `this` query, either by `query.stop()` or by an exception. If
- * the query has terminated with an exception, then the exception will be thrown. Otherwise, it
- * returns whether the query has terminated or not within the `timeoutMs` milliseconds.
- *
- * If the query has terminated, then all subsequent calls to this method will either return
- * `true` immediately (if the query was terminated by `stop()`), or throw the exception
- * immediately (if the query has terminated with exception).
- *
- * @throws StreamingQueryException
- * if the query has terminated with an exception
- * @since 3.5.0
- */
- @throws[StreamingQueryException]
- def awaitTermination(timeoutMs: Long): Boolean
-
- /**
- * Blocks until all available data in the source has been processed and committed to the sink.
- * This method is intended for testing. Note that in the case of continually arriving data, this
- * method may block forever. Additionally, this method is only guaranteed to block until data
- * that has been synchronously appended data to a
- * `org.apache.spark.sql.execution.streaming.Source` prior to invocation. (i.e. `getOffset` must
- * immediately reflect the addition).
- * @since 3.5.0
- */
- def processAllAvailable(): Unit
-
- /**
- * Stops the execution of this query if it is running. This waits until the termination of the
- * query execution threads or until a timeout is hit.
- *
- * By default stop will block indefinitely. You can configure a timeout by the configuration
- * `spark.sql.streaming.stopTimeout`. A timeout of 0 (or negative) milliseconds will block
- * indefinitely. If a `TimeoutException` is thrown, users can retry stopping the stream. If the
- * issue persists, it is advisable to kill the Spark application.
- *
- * @since 3.5.0
- */
- @throws[TimeoutException]
- def stop(): Unit
-
- /**
- * Prints the physical plan to the console for debugging purposes.
- * @since 3.5.0
- */
- def explain(): Unit
-
- /**
- * Prints the physical plan to the console for debugging purposes.
- *
- * @param extended
- * whether to do extended explain or not
- * @since 3.5.0
- */
- def explain(extended: Boolean): Unit
+/** @inheritdoc */
+trait StreamingQuery extends api.StreamingQuery {
+
+ /** @inheritdoc */
+ override def sparkSession: SparkSession
}
class RemoteStreamingQuery(
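
The scaladoc deleted above now lives on the shared `api.StreamingQuery` trait, so the Connect client only re-declares `sparkSession`. A minimal usage sketch of the lifecycle methods that doc covers, assuming a local session and the built-in rate/console formats (all names below are illustrative, not part of this patch):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().getOrCreate()
    val query = spark.readStream
      .format("rate")
      .load()
      .writeStream
      .format("console")
      .start()

    assert(query.isActive)
    println(query.status)                                  // current StreamingQueryStatus
    query.recentProgress.foreach(p => println(p.batchId))  // recent StreamingQueryProgress updates
    // Wait up to 10 seconds; returns false if the query is still running.
    if (!query.awaitTermination(10000L)) {
      query.stop()
    }
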
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index 0926734ef4872..4a76c380f772f 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -23,7 +23,7 @@ import java.util.Properties
import scala.collection.mutable
import scala.concurrent.{ExecutionContext, Future}
-import scala.concurrent.duration.DurationInt
+import scala.concurrent.duration.{DurationInt, FiniteDuration}
import scala.jdk.CollectionConverters._
import org.apache.commons.io.FileUtils
@@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException,
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.catalyst.parser.ParseException
-import org.apache.spark.sql.connect.client.{SparkConnectClient, SparkResult}
+import org.apache.spark.sql.connect.client.{RetryPolicy, SparkConnectClient, SparkResult}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SqlApiConf
import org.apache.spark.sql.test.{ConnectFunSuite, IntegrationTestUtils, RemoteSparkSession, SQLHelper}
@@ -1569,6 +1569,25 @@ class ClientE2ETestSuite
val result = df.select(trim(col("col"), " ").as("trimmed_col")).collect()
assert(result sameElements Array(Row("a"), Row("b"), Row("c")))
}
+
+ test("SPARK-49673: new batch size, multiple batches") {
+ val maxBatchSize = spark.conf.get("spark.connect.grpc.arrow.maxBatchSize").dropRight(1).toInt
+ // Adjust client grpcMaxMessageSize to maxBatchSize (10MiB; set in RemoteSparkSession config)
+ val sparkWithLowerMaxMessageSize = SparkSession
+ .builder()
+ .client(
+ SparkConnectClient
+ .builder()
+ .userId("test")
+ .port(port)
+ .grpcMaxMessageSize(maxBatchSize)
+ .retryPolicy(RetryPolicy
+ .defaultPolicy()
+ .copy(maxRetries = Some(10), maxBackoff = Some(FiniteDuration(30, "s"))))
+ .build())
+ .create()
+ assert(sparkWithLowerMaxMessageSize.range(maxBatchSize).collect().length == maxBatchSize)
+ }
}
private[sql] case class ClassData(a: String, b: Int)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index 474eac138ab78..315f80e13eff7 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -552,6 +552,14 @@ class PlanGenerationTestSuite
valueColumnName = "value")
}
+ test("transpose index_column") {
+ simple.transpose(indexColumn = fn.col("id"))
+ }
+
+ test("transpose no_index_column") {
+ simple.transpose()
+ }
+
test("offset") {
simple.offset(1000)
}
@@ -1801,7 +1809,11 @@ class PlanGenerationTestSuite
fn.sentences(fn.col("g"))
}
- functionTest("sentences with locale") {
+ functionTest("sentences with language") {
+ fn.sentences(fn.col("g"), lit("en"))
+ }
+
+ functionTest("sentences with language and country") {
fn.sentences(fn.col("g"), lit("en"), lit("US"))
}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 10b31155376fb..16f6983efb187 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -145,21 +145,6 @@ object CheckConnectJvmClientCompatibility {
checkMiMaCompatibility(clientJar, protobufJar, includedRules, excludeRules)
}
- private lazy val mergeIntoWriterExcludeRules: Seq[ProblemFilter] = {
- // Exclude some auto-generated methods in [[MergeIntoWriter]] classes.
- // The incompatible changes are due to the uses of [[proto.Expression]] instead
- // of [[catalyst.Expression]] in the method signature.
- val classNames = Seq("WhenMatched", "WhenNotMatched", "WhenNotMatchedBySource")
- val methodNames = Seq("apply", "condition", "copy", "copy$*", "unapply")
-
- classNames.flatMap { className =>
- methodNames.map { methodName =>
- ProblemFilters.exclude[IncompatibleSignatureProblem](
- s"org.apache.spark.sql.$className.$methodName")
- }
- }
- }
-
private def checkMiMaCompatibilityWithSqlModule(
clientJar: File,
sqlJar: File): List[Problem] = {
@@ -173,6 +158,7 @@ object CheckConnectJvmClientCompatibility {
ProblemFilters.exclude[Problem]("org.apache.spark.sql.catalyst.*"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.columnar.*"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.connector.*"),
+ ProblemFilters.exclude[Problem]("org.apache.spark.sql.classic.*"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.execution.*"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.internal.*"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.jdbc.*"),
@@ -269,12 +255,6 @@ object CheckConnectJvmClientCompatibility {
"org.apache.spark.sql.streaming.TestGroupState"),
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.sql.streaming.TestGroupState$"),
- ProblemFilters.exclude[MissingClassProblem](
- "org.apache.spark.sql.streaming.PythonStreamingQueryListener"),
- ProblemFilters.exclude[MissingClassProblem](
- "org.apache.spark.sql.streaming.PythonStreamingQueryListenerWrapper"),
- ProblemFilters.exclude[MissingTypesProblem](
- "org.apache.spark.sql.streaming.StreamingQueryListener$Event"),
// SQLImplicits
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SQLImplicits.rddToDatasetHolder"),
@@ -286,10 +266,6 @@ object CheckConnectJvmClientCompatibility {
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.sql.artifact.ArtifactManager$"),
- // UDFRegistration
- ProblemFilters.exclude[DirectMissingMethodProblem](
- "org.apache.spark.sql.UDFRegistration.register"),
-
// ColumnNode conversions
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.SparkSession.Converter"),
@@ -304,6 +280,8 @@ object CheckConnectJvmClientCompatibility {
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.scalalang.*"),
// UDFRegistration
+ ProblemFilters.exclude[DirectMissingMethodProblem](
+ "org.apache.spark.sql.UDFRegistration.register"),
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.UDFRegistration"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.UDFRegistration.log*"),
@@ -320,12 +298,13 @@ object CheckConnectJvmClientCompatibility {
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.UDFRegistration.initializeLogIfNecessary$default$2"),
- // Datasource V2 partition transforms
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.PartitionTransform"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.PartitionTransform$"),
- ProblemFilters.exclude[MissingClassProblem](
- "org.apache.spark.sql.PartitionTransform$ExtractTransform")) ++
- mergeIntoWriterExcludeRules
+ // Protected DataFrameReader methods...
+ ProblemFilters.exclude[DirectMissingMethodProblem](
+ "org.apache.spark.sql.DataFrameReader.validateSingleVariantColumn"),
+ ProblemFilters.exclude[DirectMissingMethodProblem](
+ "org.apache.spark.sql.DataFrameReader.validateJsonSchema"),
+ ProblemFilters.exclude[DirectMissingMethodProblem](
+ "org.apache.spark.sql.DataFrameReader.validateXmlSchema"))
checkMiMaCompatibility(clientJar, sqlJar, includedRules, excludeRules)
}
@@ -387,21 +366,6 @@ object CheckConnectJvmClientCompatibility {
// Experimental
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.SparkSession.registerClassFinder"),
- // public
- ProblemFilters.exclude[DirectMissingMethodProblem](
- "org.apache.spark.sql.SparkSession.interruptAll"),
- ProblemFilters.exclude[DirectMissingMethodProblem](
- "org.apache.spark.sql.SparkSession.interruptTag"),
- ProblemFilters.exclude[DirectMissingMethodProblem](
- "org.apache.spark.sql.SparkSession.interruptOperation"),
- ProblemFilters.exclude[DirectMissingMethodProblem](
- "org.apache.spark.sql.SparkSession.addTag"),
- ProblemFilters.exclude[DirectMissingMethodProblem](
- "org.apache.spark.sql.SparkSession.removeTag"),
- ProblemFilters.exclude[DirectMissingMethodProblem](
- "org.apache.spark.sql.SparkSession.getTags"),
- ProblemFilters.exclude[DirectMissingMethodProblem](
- "org.apache.spark.sql.SparkSession.clearTags"),
// SparkSession#Builder
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.SparkSession#Builder.remote"),
@@ -435,8 +399,7 @@ object CheckConnectJvmClientCompatibility {
// Encoders are in the wrong JAR
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Encoders"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Encoders$")) ++
- mergeIntoWriterExcludeRules
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Encoders$"))
checkMiMaCompatibility(sqlJar, clientJar, includedRules, excludeRules)
}
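
The exclude lists edited above are plain MiMa `ProblemFilter`s. A small sketch of how such a list is assembled, with placeholder class and method names rather than rules from this patch:

    import com.typesafe.tools.mima.core._

    val excludeRules: Seq[ProblemFilter] = Seq(
      // Ignore every incompatibility under an entire (private) package.
      ProblemFilters.exclude[Problem]("org.apache.spark.sql.internal.*"),
      // Ignore one missing method on a specific class.
      ProblemFilters.exclude[DirectMissingMethodProblem](
        "org.apache.spark.sql.SomeClass.someMethod"))
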
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
index 70b471cf74b33..5397dae9dcc5f 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
@@ -30,7 +30,7 @@ import org.apache.arrow.memory.{BufferAllocator, RootAllocator}
import org.apache.arrow.vector.VarBinaryVector
import org.scalatest.BeforeAndAfterAll
-import org.apache.spark.{sql, SparkUnsupportedOperationException}
+import org.apache.spark.{sql, SparkRuntimeException, SparkUnsupportedOperationException}
import org.apache.spark.sql.{AnalysisException, Encoders, Row}
import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, JavaTypeInference, ScalaReflection}
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, Codec, OuterScopes}
@@ -776,6 +776,16 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll {
}
}
+ test("kryo serialization") {
+ val e = intercept[SparkRuntimeException] {
+ val encoder = sql.encoderFor(Encoders.kryo[(Int, String)])
+ roundTripAndCheckIdentical(encoder) { () =>
+ Iterator.tabulate(10)(i => (i, "itr_" + i))
+ }
+ }
+ assert(e.getErrorClass == "CANNOT_USE_KRYO")
+ }
+
test("transforming encoder") {
val schema = new StructType()
.add("key", IntegerType)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
index 758262ead7f1e..27b1ee014a719 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
@@ -334,8 +334,6 @@ class ClientStreamingQuerySuite extends QueryTest with RemoteSparkSession with L
assert(exception.getErrorClass != null)
assert(exception.getMessageParameters().get("id") == query.id.toString)
assert(exception.getMessageParameters().get("runId") == query.runId.toString)
- assert(!exception.getMessageParameters().get("startOffset").isEmpty)
- assert(!exception.getMessageParameters().get("endOffset").isEmpty)
assert(exception.getCause.isInstanceOf[SparkException])
assert(exception.getCause.getCause.isInstanceOf[SparkException])
assert(
@@ -374,8 +372,6 @@ class ClientStreamingQuerySuite extends QueryTest with RemoteSparkSession with L
assert(exception.getErrorClass != null)
assert(exception.getMessageParameters().get("id") == query.id.toString)
assert(exception.getMessageParameters().get("runId") == query.runId.toString)
- assert(!exception.getMessageParameters().get("startOffset").isEmpty)
- assert(!exception.getMessageParameters().get("endOffset").isEmpty)
assert(exception.getCause.isInstanceOf[SparkException])
assert(exception.getCause.getCause.isInstanceOf[SparkException])
assert(
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala
index a878e42b40aa7..36aaa2cc7fbf6 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala
@@ -24,6 +24,9 @@ import java.util.concurrent.TimeUnit
import scala.concurrent.duration.FiniteDuration
import org.scalatest.{BeforeAndAfterAll, Suite}
+import org.scalatest.concurrent.Eventually.eventually
+import org.scalatest.concurrent.Futures.timeout
+import org.scalatest.time.SpanSugar._
import org.apache.spark.SparkBuildInfo
import org.apache.spark.sql.SparkSession
@@ -121,6 +124,8 @@ object SparkConnectServerUtils {
// to make the tests exercise reattach.
"spark.connect.execute.reattachable.senderMaxStreamDuration=1s",
"spark.connect.execute.reattachable.senderMaxStreamSize=123",
+ // Testing SPARK-49673, setting maxBatchSize to 10MiB
+ s"spark.connect.grpc.arrow.maxBatchSize=${10 * 1024 * 1024}",
// Disable UI
"spark.ui.enabled=false")
Seq("--jars", catalystTestJar) ++ confs.flatMap(v => "--conf" :: v :: Nil)
@@ -184,12 +189,14 @@ object SparkConnectServerUtils {
.port(port)
.retryPolicy(RetryPolicy
.defaultPolicy()
- .copy(maxRetries = Some(7), maxBackoff = Some(FiniteDuration(10, "s"))))
+ .copy(maxRetries = Some(10), maxBackoff = Some(FiniteDuration(30, "s"))))
.build())
.create()
// Execute an RPC which will get retried until the server is up.
- assert(spark.version == SparkBuildInfo.spark_version)
+ eventually(timeout(1.minute)) {
+ assert(spark.version == SparkBuildInfo.spark_version)
+ }
// Auto-sync dependencies.
SparkConnectServerUtils.syncTestDependencies(spark)
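
Wrapping the first RPC in `eventually` means a slow server start-up is retried instead of failing the suite outright. The retry pattern, using the same ScalaTest imports as the hunk above and a placeholder condition:

    import org.scalatest.concurrent.Eventually.eventually
    import org.scalatest.concurrent.Futures.timeout
    import org.scalatest.time.SpanSugar._

    var serverStarted = false   // in real tests this is flipped by the server thread
    eventually(timeout(1.minute)) {
      assert(serverStarted, "server did not come up in time")
    }
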
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
index 30009c03c49fd..90cd68e6e1d24 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
@@ -490,7 +490,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
.option("query", "SELECT @myvariant1 as variant1, @myvariant2 as variant2")
.load()
},
- errorClass = "UNRECOGNIZED_SQL_TYPE",
+ condition = "UNRECOGNIZED_SQL_TYPE",
parameters = Map("typeName" -> "sql_variant", "jdbcType" -> "-156"))
}
}
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala
index b337eb2fc9b3b..91a82075a3607 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala
@@ -87,7 +87,7 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest {
exception = intercept[AnalysisException] {
sql(sql1)
},
- errorClass = "NOT_SUPPORTED_CHANGE_COLUMN",
+ condition = "NOT_SUPPORTED_CHANGE_COLUMN",
parameters = Map(
"originType" -> "\"DOUBLE\"",
"newType" -> "\"VARCHAR(10)\"",
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala
index 27ec98e9ac451..e5fd453cb057c 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala
@@ -97,7 +97,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD
exception = intercept[AnalysisException] {
sql(sql1)
},
- errorClass = "NOT_SUPPORTED_CHANGE_COLUMN",
+ condition = "NOT_SUPPORTED_CHANGE_COLUMN",
parameters = Map(
"originType" -> "\"STRING\"",
"newType" -> "\"INT\"",
@@ -115,7 +115,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD
exception = intercept[SparkSQLFeatureNotSupportedException] {
sql(s"ALTER TABLE $tbl ALTER COLUMN ID DROP NOT NULL")
},
- errorClass = "_LEGACY_ERROR_TEMP_2271")
+ condition = "_LEGACY_ERROR_TEMP_2271")
}
test("SPARK-47440: SQLServer does not support boolean expression in binary comparison") {
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala
index 81aacf2c14d7a..700c05b54a256 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala
@@ -77,8 +77,19 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest
s"""CREATE TABLE pattern_testing_table (
|pattern_testing_col LONGTEXT
|)
- """.stripMargin
+ |""".stripMargin
).executeUpdate()
+ connection.prepareStatement(
+ "CREATE TABLE datetime (name VARCHAR(32), date1 DATE, time1 TIMESTAMP)")
+ .executeUpdate()
+ }
+
+ override def dataPreparation(connection: Connection): Unit = {
+ super.dataPreparation(connection)
+ connection.prepareStatement("INSERT INTO datetime VALUES " +
+ "('amy', '2022-05-19', '2022-05-19 00:00:00')").executeUpdate()
+ connection.prepareStatement("INSERT INTO datetime VALUES " +
+ "('alex', '2022-05-18', '2022-05-18 00:00:00')").executeUpdate()
}
override def testUpdateColumnType(tbl: String): Unit = {
@@ -98,7 +109,7 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest
exception = intercept[AnalysisException] {
sql(sql1)
},
- errorClass = "NOT_SUPPORTED_CHANGE_COLUMN",
+ condition = "NOT_SUPPORTED_CHANGE_COLUMN",
parameters = Map(
"originType" -> "\"STRING\"",
"newType" -> "\"INT\"",
@@ -131,7 +142,7 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest
exception = intercept[SparkSQLFeatureNotSupportedException] {
sql(s"ALTER TABLE $tbl ALTER COLUMN ID DROP NOT NULL")
},
- errorClass = "_LEGACY_ERROR_TEMP_2271")
+ condition = "_LEGACY_ERROR_TEMP_2271")
}
override def testCreateTableWithProperty(tbl: String): Unit = {
@@ -157,6 +168,79 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest
assert(sql(s"SELECT char_length(c1) from $tableName").head().get(0) === 65536)
}
}
+
+ override def testDatetime(tbl: String): Unit = {
+ val df1 = sql(s"SELECT name FROM $tbl WHERE " +
+ "dayofyear(date1) > 100 AND dayofmonth(date1) > 10 ")
+ checkFilterPushed(df1)
+ val rows1 = df1.collect()
+ assert(rows1.length === 2)
+ assert(rows1(0).getString(0) === "amy")
+ assert(rows1(1).getString(0) === "alex")
+
+ val df2 = sql(s"SELECT name FROM $tbl WHERE year(date1) = 2022 AND quarter(date1) = 2")
+ checkFilterPushed(df2)
+ val rows2 = df2.collect()
+ assert(rows2.length === 2)
+ assert(rows2(0).getString(0) === "amy")
+ assert(rows2(1).getString(0) === "alex")
+
+ val df3 = sql(s"SELECT name FROM $tbl WHERE second(time1) = 0 AND month(date1) = 5")
+ checkFilterPushed(df3)
+ val rows3 = df3.collect()
+ assert(rows3.length === 2)
+ assert(rows3(0).getString(0) === "amy")
+ assert(rows3(1).getString(0) === "alex")
+
+ val df4 = sql(s"SELECT name FROM $tbl WHERE hour(time1) = 0 AND minute(time1) = 0")
+ checkFilterPushed(df4)
+ val rows4 = df4.collect()
+ assert(rows4.length === 2)
+ assert(rows4(0).getString(0) === "amy")
+ assert(rows4(1).getString(0) === "alex")
+
+ val df5 = sql(s"SELECT name FROM $tbl WHERE " +
+ "extract(WEEk from date1) > 10 AND extract(YEAROFWEEK from date1) = 2022")
+ checkFilterPushed(df5)
+ val rows5 = df5.collect()
+ assert(rows5.length === 2)
+ assert(rows5(0).getString(0) === "amy")
+ assert(rows5(1).getString(0) === "alex")
+
+ val df6 = sql(s"SELECT name FROM $tbl WHERE date_add(date1, 1) = date'2022-05-20' " +
+ "AND datediff(date1, '2022-05-10') > 0")
+ checkFilterPushed(df6)
+ val rows6 = df6.collect()
+ assert(rows6.length === 1)
+ assert(rows6(0).getString(0) === "amy")
+
+ val df7 = sql(s"SELECT name FROM $tbl WHERE weekday(date1) = 2")
+ checkFilterPushed(df7)
+ val rows7 = df7.collect()
+ assert(rows7.length === 1)
+ assert(rows7(0).getString(0) === "alex")
+
+ val df8 = sql(s"SELECT name FROM $tbl WHERE dayofweek(date1) = 4")
+ checkFilterPushed(df8)
+ val rows8 = df8.collect()
+ assert(rows8.length === 1)
+ assert(rows8(0).getString(0) === "alex")
+
+ val df9 = sql(s"SELECT name FROM $tbl WHERE " +
+ "dayofyear(date1) > 100 order by dayofyear(date1) limit 1")
+ checkFilterPushed(df9)
+ val rows9 = df9.collect()
+ assert(rows9.length === 1)
+ assert(rows9(0).getString(0) === "alex")
+
+ // MySQL does not support the TRUNC function, so this filter is not pushed down.
+ val df10 = sql(s"SELECT name FROM $tbl WHERE trunc(date1, 'week') = date'2022-05-16'")
+ checkFilterPushed(df10, false)
+ val rows10 = df10.collect()
+ assert(rows10.length === 2)
+ assert(rows10(0).getString(0) === "amy")
+ assert(rows10(1).getString(0) === "alex")
+ }
}
/**
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala
index 42d82233b421b..5e40f0bbc4554 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala
@@ -62,7 +62,7 @@ class MySQLNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespac
exception = intercept[SparkSQLFeatureNotSupportedException] {
catalog.createNamespace(Array("foo"), Map("comment" -> "test comment").asJava)
},
- errorClass = "UNSUPPORTED_FEATURE.COMMENT_NAMESPACE",
+ condition = "UNSUPPORTED_FEATURE.COMMENT_NAMESPACE",
parameters = Map("namespace" -> "`foo`")
)
assert(catalog.namespaceExists(Array("foo")) === false)
@@ -74,7 +74,7 @@ class MySQLNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespac
Array("foo"),
NamespaceChange.setProperty("comment", "comment for foo"))
},
- errorClass = "UNSUPPORTED_FEATURE.COMMENT_NAMESPACE",
+ condition = "UNSUPPORTED_FEATURE.COMMENT_NAMESPACE",
parameters = Map("namespace" -> "`foo`")
)
@@ -82,7 +82,7 @@ class MySQLNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespac
exception = intercept[SparkSQLFeatureNotSupportedException] {
catalog.alterNamespace(Array("foo"), NamespaceChange.removeProperty("comment"))
},
- errorClass = "UNSUPPORTED_FEATURE.REMOVE_NAMESPACE_COMMENT",
+ condition = "UNSUPPORTED_FEATURE.REMOVE_NAMESPACE_COMMENT",
parameters = Map("namespace" -> "`foo`")
)
@@ -90,7 +90,7 @@ class MySQLNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespac
exception = intercept[SparkSQLFeatureNotSupportedException] {
catalog.dropNamespace(Array("foo"), cascade = false)
},
- errorClass = "UNSUPPORTED_FEATURE.DROP_NAMESPACE",
+ condition = "UNSUPPORTED_FEATURE.DROP_NAMESPACE",
parameters = Map("namespace" -> "`foo`")
)
catalog.dropNamespace(Array("foo"), cascade = true)
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala
index 342fb4bb38e60..2c97a588670a8 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala
@@ -118,7 +118,7 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes
exception = intercept[AnalysisException] {
sql(sql1)
},
- errorClass = "NOT_SUPPORTED_CHANGE_COLUMN",
+ condition = "NOT_SUPPORTED_CHANGE_COLUMN",
parameters = Map(
"originType" -> "\"DECIMAL(19,0)\"",
"newType" -> "\"INT\"",
@@ -139,7 +139,7 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes
exception = intercept[SparkRuntimeException] {
sql(s"INSERT INTO $tableName SELECT rpad('hi', 256, 'spark')")
},
- errorClass = "EXCEED_LIMIT_LENGTH",
+ condition = "EXCEED_LIMIT_LENGTH",
parameters = Map("limit" -> "255")
)
}
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala
index e22136a09a56c..850391e8dc33c 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala
@@ -84,7 +84,7 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT
exception = intercept[AnalysisException] {
sql(sql1)
},
- errorClass = "NOT_SUPPORTED_CHANGE_COLUMN",
+ condition = "NOT_SUPPORTED_CHANGE_COLUMN",
parameters = Map(
"originType" -> "\"STRING\"",
"newType" -> "\"INT\"",
@@ -118,7 +118,7 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT
sql(s"CREATE TABLE $t2(c int)")
checkError(
exception = intercept[TableAlreadyExistsException](sql(s"ALTER TABLE $t1 RENAME TO t2")),
- errorClass = "TABLE_OR_VIEW_ALREADY_EXISTS",
+ condition = "TABLE_OR_VIEW_ALREADY_EXISTS",
parameters = Map("relationName" -> "`t2`")
)
}
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala
index e4cc88cec0f5e..3b1a457214be7 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala
@@ -92,7 +92,7 @@ private[v2] trait V2JDBCNamespaceTest extends SharedSparkSession with DockerInte
catalog.listNamespaces(Array("foo"))
}
checkError(e,
- errorClass = "SCHEMA_NOT_FOUND",
+ condition = "SCHEMA_NOT_FOUND",
parameters = Map("schemaName" -> "`foo`"))
}
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala
index b0ab614b27d1f..54635f69f8b65 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala
@@ -71,7 +71,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
exception = intercept[AnalysisException] {
sql(sqlText)
},
- errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION",
+ condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION",
sqlState = "42703",
parameters = Map(
"objectName" -> "`bad_column`",
@@ -92,11 +92,11 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
private def checkErrorFailedJDBC(
e: AnalysisException,
- errorClass: String,
+ condition: String,
tbl: String): Unit = {
checkErrorMatchPVals(
exception = e,
- errorClass = errorClass,
+ condition = condition,
parameters = Map(
"url" -> "jdbc:.*",
"tableName" -> s"`$tbl`")
@@ -126,7 +126,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
exception = intercept[AnalysisException] {
sql(s"ALTER TABLE $catalogName.alt_table ADD COLUMNS (C3 DOUBLE)")
},
- errorClass = "FIELD_ALREADY_EXISTS",
+ condition = "FIELD_ALREADY_EXISTS",
parameters = Map(
"op" -> "add",
"fieldNames" -> "`C3`",
@@ -159,7 +159,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
exception = intercept[AnalysisException] {
sql(sqlText)
},
- errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION",
+ condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION",
sqlState = "42703",
parameters = Map(
"objectName" -> "`bad_column`",
@@ -182,7 +182,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
exception = intercept[AnalysisException] {
sql(sqlText)
},
- errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION",
+ condition = "UNRESOLVED_COLUMN.WITH_SUGGESTION",
sqlState = "42703",
parameters = Map(
"objectName" -> "`bad_column`",
@@ -206,7 +206,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
exception = intercept[AnalysisException] {
sql(s"ALTER TABLE $catalogName.alt_table RENAME COLUMN ID1 TO ID2")
},
- errorClass = "FIELD_ALREADY_EXISTS",
+ condition = "FIELD_ALREADY_EXISTS",
parameters = Map(
"op" -> "rename",
"fieldNames" -> "`ID2`",
@@ -308,7 +308,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
exception = intercept[IndexAlreadyExistsException] {
sql(s"CREATE index i1 ON $catalogName.new_table (col1)")
},
- errorClass = "INDEX_ALREADY_EXISTS",
+ condition = "INDEX_ALREADY_EXISTS",
parameters = Map("indexName" -> "`i1`", "tableName" -> "`new_table`")
)
@@ -333,7 +333,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
exception = intercept[NoSuchIndexException] {
sql(s"DROP index i1 ON $catalogName.new_table")
},
- errorClass = "INDEX_NOT_FOUND",
+ condition = "INDEX_NOT_FOUND",
parameters = Map("indexName" -> "`i1`", "tableName" -> "`new_table`")
)
}
@@ -353,7 +353,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
}
}
- private def checkFilterPushed(df: DataFrame, pushed: Boolean = true): Unit = {
+ protected def checkFilterPushed(df: DataFrame, pushed: Boolean = true): Unit = {
val filter = df.queryExecution.optimizedPlan.collect {
case f: Filter => f
}
@@ -975,9 +975,15 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
exception = intercept[AnalysisException] {
sql(s"ALTER TABLE $catalogName.tbl2 RENAME TO tbl1")
},
- errorClass = "TABLE_OR_VIEW_ALREADY_EXISTS",
+ condition = "TABLE_OR_VIEW_ALREADY_EXISTS",
parameters = Map("relationName" -> "`tbl1`")
)
}
}
+
+ def testDatetime(tbl: String): Unit = {}
+
+ test("scan with filter push-down with date time functions") {
+ testDatetime(s"$catalogAndNamespace.${caseConvert("datetime")}")
+ }
}
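
`checkFilterPushed` is made `protected` above so dialect suites such as the MySQL one can reuse it from `testDatetime`. What it verifies, restated as a self-contained helper whose logic mirrors the method shown in this hunk:

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.catalyst.plans.logical.Filter

    def assertFilterPushed(df: DataFrame, pushed: Boolean = true): Unit = {
      // A fully pushed-down predicate leaves no Filter node in the optimized plan.
      val filters = df.queryExecution.optimizedPlan.collect { case f: Filter => f }
      if (pushed) assert(filters.isEmpty) else assert(filters.nonEmpty)
    }
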
diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToRowConverter.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToRowConverter.scala
index 56456f9b1f776..8d0bcc5816775 100644
--- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToRowConverter.scala
+++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRecordToRowConverter.scala
@@ -50,7 +50,7 @@ private[kafka010] class KafkaRecordToRowConverter {
new GenericArrayData(cr.headers.iterator().asScala
.map(header =>
InternalRow(UTF8String.fromString(header.key()), header.value())
- ).toArray)
+ ).toArray[Any])
} else {
null
}
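
`GenericArrayData` stores its elements as `Array[Any]`, so spelling out the element type avoids relying on how an inferred `Array[InternalRow]` would be adapted to that constructor. A standalone sketch of the same construction with made-up header keys and values:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.util.GenericArrayData
    import org.apache.spark.unsafe.types.UTF8String

    val headers: Array[Any] = Iterator(("k1", "v1"), ("k2", "v2"))
      .map { case (k, v) => InternalRow(UTF8String.fromString(k), v.getBytes("UTF-8")) }
      .toArray[Any]
    val arrayData = new GenericArrayData(headers)
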
diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
index 9ae6a9290f80a..1d119de43970f 100644
--- a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
+++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
@@ -1156,7 +1156,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase with
test("allow group.id prefix") {
// Group ID prefix is only supported by consumer based offset reader
- if (spark.conf.get(SQLConf.USE_DEPRECATED_KAFKA_OFFSET_FETCHING)) {
+ if (sqlConf.getConf(SQLConf.USE_DEPRECATED_KAFKA_OFFSET_FETCHING)) {
testGroupId("groupIdPrefix", (expected, actual) => {
assert(actual.exists(_.startsWith(expected)) && !actual.exists(_ === expected),
"Valid consumer groups don't contain the expected group id - " +
@@ -1167,7 +1167,7 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase with
test("allow group.id override") {
// Group ID override is only supported by consumer based offset reader
- if (spark.conf.get(SQLConf.USE_DEPRECATED_KAFKA_OFFSET_FETCHING)) {
+ if (sqlConf.getConf(SQLConf.USE_DEPRECATED_KAFKA_OFFSET_FETCHING)) {
testGroupId("kafka.group.id", (expected, actual) => {
assert(actual.exists(_ === expected), "Valid consumer groups don't " +
s"contain the expected group id - Valid consumer groups: $actual / " +
diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderSuite.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderSuite.scala
index 320485a79e59d..6fc22e7ac5e03 100644
--- a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderSuite.scala
+++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderSuite.scala
@@ -153,7 +153,7 @@ class KafkaOffsetReaderSuite extends QueryTest with SharedSparkSession with Kafk
}
checkError(
exception = ex,
- errorClass = "KAFKA_START_OFFSET_DOES_NOT_MATCH_ASSIGNED",
+ condition = "KAFKA_START_OFFSET_DOES_NOT_MATCH_ASSIGNED",
parameters = Map(
"specifiedPartitions" -> "Set\\(.*,.*\\)",
"assignedPartitions" -> "Set\\(.*,.*,.*\\)"),
diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala
index 31050887936bd..3b0def8fc73f7 100644
--- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala
@@ -20,7 +20,7 @@ import scala.jdk.CollectionConverters._
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.Column
-import org.apache.spark.sql.internal.ExpressionUtils.{column, expression}
+import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.protobuf.utils.ProtobufUtils
// scalastyle:off: object.name
@@ -71,7 +71,13 @@ object functions {
messageName: String,
binaryFileDescriptorSet: Array[Byte],
options: java.util.Map[String, String]): Column = {
- ProtobufDataToCatalyst(data, messageName, Some(binaryFileDescriptorSet), options.asScala.toMap)
+ Column.fnWithOptions(
+ "from_protobuf",
+ options.asScala.iterator,
+ data,
+ lit(messageName),
+ lit(binaryFileDescriptorSet)
+ )
}
/**
@@ -90,7 +96,7 @@ object functions {
@Experimental
def from_protobuf(data: Column, messageName: String, descFilePath: String): Column = {
val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath)
- ProtobufDataToCatalyst(data, messageName, Some(fileContent))
+ from_protobuf(data, messageName, fileContent)
}
/**
@@ -109,7 +115,12 @@ object functions {
@Experimental
def from_protobuf(data: Column, messageName: String, binaryFileDescriptorSet: Array[Byte])
: Column = {
- ProtobufDataToCatalyst(data, messageName, Some(binaryFileDescriptorSet))
+ Column.fn(
+ "from_protobuf",
+ data,
+ lit(messageName),
+ lit(binaryFileDescriptorSet)
+ )
}
/**
@@ -129,7 +140,11 @@ object functions {
*/
@Experimental
def from_protobuf(data: Column, messageClassName: String): Column = {
- ProtobufDataToCatalyst(data, messageClassName)
+ Column.fn(
+ "from_protobuf",
+ data,
+ lit(messageClassName)
+ )
}
/**
@@ -153,7 +168,12 @@ object functions {
data: Column,
messageClassName: String,
options: java.util.Map[String, String]): Column = {
- ProtobufDataToCatalyst(data, messageClassName, None, options.asScala.toMap)
+ Column.fnWithOptions(
+ "from_protobuf",
+ options.asScala.iterator,
+ data,
+ lit(messageClassName)
+ )
}
/**
@@ -191,7 +211,12 @@ object functions {
@Experimental
def to_protobuf(data: Column, messageName: String, binaryFileDescriptorSet: Array[Byte])
: Column = {
- CatalystDataToProtobuf(data, messageName, Some(binaryFileDescriptorSet))
+ Column.fn(
+ "to_protobuf",
+ data,
+ lit(messageName),
+ lit(binaryFileDescriptorSet)
+ )
}
/**
* Converts a column into binary of protobuf format. The Protobuf definition is provided
@@ -213,7 +238,7 @@ object functions {
descFilePath: String,
options: java.util.Map[String, String]): Column = {
val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath)
- CatalystDataToProtobuf(data, messageName, Some(fileContent), options.asScala.toMap)
+ to_protobuf(data, messageName, fileContent, options)
}
/**
@@ -237,7 +262,13 @@ object functions {
binaryFileDescriptorSet: Array[Byte],
options: java.util.Map[String, String]
): Column = {
- CatalystDataToProtobuf(data, messageName, Some(binaryFileDescriptorSet), options.asScala.toMap)
+ Column.fnWithOptions(
+ "to_protobuf",
+ options.asScala.iterator,
+ data,
+ lit(messageName),
+ lit(binaryFileDescriptorSet)
+ )
}
/**
@@ -257,7 +288,11 @@ object functions {
*/
@Experimental
def to_protobuf(data: Column, messageClassName: String): Column = {
- CatalystDataToProtobuf(data, messageClassName)
+ Column.fn(
+ "to_protobuf",
+ data,
+ lit(messageClassName)
+ )
}
/**
@@ -279,6 +314,11 @@ object functions {
@Experimental
def to_protobuf(data: Column, messageClassName: String, options: java.util.Map[String, String])
: Column = {
- CatalystDataToProtobuf(data, messageClassName, None, options.asScala.toMap)
+ Column.fnWithOptions(
+ "to_protobuf",
+ options.asScala.iterator,
+ data,
+ lit(messageClassName)
+ )
}
}
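
After this refactor both converters route through registered SQL functions via `Column.fn` / `Column.fnWithOptions` rather than building catalyst expressions directly. A usage sketch of the descriptor-set variants, with an illustrative message name, descriptor path and input data:

    import java.nio.file.{Files, Paths}
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.struct
    import org.apache.spark.sql.protobuf.functions.{from_protobuf, to_protobuf}

    val spark = SparkSession.builder().getOrCreate()
    val descriptorBytes: Array[Byte] = Files.readAllBytes(Paths.get("/tmp/person.desc"))

    val df = spark.range(3).selectExpr("id", "cast(id as string) as name")
    val asProto = df.select(
      to_protobuf(struct(df("id"), df("name")), "Person", descriptorBytes).as("payload"))
    val back = asProto.select(
      from_protobuf(asProto("payload"), "Person", descriptorBytes).as("person"))
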
diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala
index 6644bce98293b..e85097a272f24 100644
--- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala
+++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala
@@ -43,8 +43,8 @@ private[sql] class ProtobufOptions(
/**
* Adds support for recursive fields. If this option is is not specified, recursive fields are
- * not permitted. Setting it to 0 drops the recursive fields, 1 allows it to be recursed once,
- * and 2 allows it to be recursed twice and so on, up to 10. Values larger than 10 are not
+ * not permitted. Setting it to 1 drops all recursive fields, 2 allows it to be recursed once,
+ * and 3 allows it to be recursed twice and so on, up to 10. Values larger than 10 are not
* allowed in order avoid inadvertently creating very large schemas. If a Protobuf message
* has depth beyond this limit, the Spark struct returned is truncated after the recursion limit.
*
@@ -52,8 +52,8 @@ private[sql] class ProtobufOptions(
* `message Person { string name = 1; Person friend = 2; }`
* The following lists the schema with different values for this setting.
* 1: `struct`
- * 2: `struct>`
- * 3: `struct>>`
+ * 2: `struct>`
+ * 3: `struct>>`
* and so on.
*/
val recursiveFieldMaxDepth: Int = parameters.getOrElse("recursive.fields.max.depth", "-1").toInt
@@ -181,7 +181,7 @@ private[sql] class ProtobufOptions(
val upcastUnsignedInts: Boolean =
parameters.getOrElse("upcast.unsigned.ints", false.toString).toBoolean
- // Whether to unwrap the struct representation for well known primitve wrapper types when
+ // Whether to unwrap the struct representation for well known primitive wrapper types when
// deserializing. By default, the wrapper types for primitives (i.e. google.protobuf.Int32Value,
// google.protobuf.Int64Value, etc.) will get deserialized as structs. We allow the option to
// deserialize them as their respective primitives.
@@ -221,7 +221,7 @@ private[sql] class ProtobufOptions(
// By default, in the spark schema field a will be dropped, which result in schema
// b struct
// If retain.empty.message.types=true, field a will be retained by inserting a dummy column.
- // b struct, name: string>
+ // b struct, name: string>
val retainEmptyMessage: Boolean =
parameters.getOrElse("retain.empty.message.types", false.toString).toBoolean
}
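
These options are passed as the `java.util.Map[String, String]` argument of `from_protobuf` / `to_protobuf`. A small sketch with illustrative values (the depth comment follows the `recursive.fields.max.depth` doc above):

    import scala.jdk.CollectionConverters._
    import org.apache.spark.sql.Column
    import org.apache.spark.sql.protobuf.functions.from_protobuf

    val protobufOptions: java.util.Map[String, String] = Map(
      "recursive.fields.max.depth" -> "2",     // 1 drops recursive fields, 2 keeps one level
      "retain.empty.message.types" -> "true",
      "upcast.unsigned.ints" -> "true").asJava

    def parsePerson(payload: Column, descriptorSet: Array[Byte]): Column =
      from_protobuf(payload, "Person", descriptorSet, protobufOptions)
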
diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
index 4ff432cf7a055..3eaa91e472c43 100644
--- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
+++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala
@@ -708,7 +708,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
}
checkError(
exception = e,
- errorClass = "PROTOBUF_DEPENDENCY_NOT_FOUND",
+ condition = "PROTOBUF_DEPENDENCY_NOT_FOUND",
parameters = Map("dependencyName" -> "nestedenum.proto"))
}
@@ -1057,7 +1057,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
}
checkError(
ex,
- errorClass = "PROTOBUF_DESCRIPTOR_FILE_NOT_FOUND",
+ condition = "PROTOBUF_DESCRIPTOR_FILE_NOT_FOUND",
parameters = Map("filePath" -> "/non/existent/path.desc")
)
assert(ex.getCause != null)
@@ -1699,7 +1699,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
}
checkError(
exception = parseError,
- errorClass = "CANNOT_CONVERT_SQL_VALUE_TO_PROTOBUF_ENUM_TYPE",
+ condition = "CANNOT_CONVERT_SQL_VALUE_TO_PROTOBUF_ENUM_TYPE",
parameters = Map(
"sqlColumn" -> "`basic_enum`",
"protobufColumn" -> "field 'basic_enum'",
@@ -1711,7 +1711,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
}
checkError(
exception = parseError,
- errorClass = "CANNOT_CONVERT_SQL_VALUE_TO_PROTOBUF_ENUM_TYPE",
+ condition = "CANNOT_CONVERT_SQL_VALUE_TO_PROTOBUF_ENUM_TYPE",
parameters = Map(
"sqlColumn" -> "`basic_enum`",
"protobufColumn" -> "field 'basic_enum'",
@@ -2093,7 +2093,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
| to_protobuf(complex_struct, 42, '$testFileDescFile', map())
|FROM protobuf_test_table
|""".stripMargin)),
- errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
+ condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
parameters = Map(
"sqlExpr" -> s"""\"to_protobuf(complex_struct, 42, $testFileDescFile, map())\"""",
"msg" -> ("The second argument of the TO_PROTOBUF SQL function must be a constant " +
@@ -2111,11 +2111,11 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
| to_protobuf(complex_struct, 'SimpleMessageJavaTypes', 42, map())
|FROM protobuf_test_table
|""".stripMargin)),
- errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
+ condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
parameters = Map(
"sqlExpr" -> "\"to_protobuf(complex_struct, SimpleMessageJavaTypes, 42, map())\"",
"msg" -> ("The third argument of the TO_PROTOBUF SQL function must be a constant " +
- "string representing the Protobuf descriptor file path"),
+ "string or binary data representing the Protobuf descriptor file path"),
"hint" -> ""),
queryContext = Array(ExpectedContext(
fragment = "to_protobuf(complex_struct, 'SimpleMessageJavaTypes', 42, map())",
@@ -2130,7 +2130,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
| to_protobuf(complex_struct, 'SimpleMessageJavaTypes', '$testFileDescFile', 42)
|FROM protobuf_test_table
|""".stripMargin)),
- errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
+ condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
parameters = Map(
"sqlExpr" ->
s"""\"to_protobuf(complex_struct, SimpleMessageJavaTypes, $testFileDescFile, 42)\"""",
@@ -2152,7 +2152,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
|SELECT from_protobuf(protobuf_data, 42, '$testFileDescFile', map())
|FROM ($toProtobufSql)
|""".stripMargin)),
- errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
+ condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
parameters = Map(
"sqlExpr" -> s"""\"from_protobuf(protobuf_data, 42, $testFileDescFile, map())\"""",
"msg" -> ("The second argument of the FROM_PROTOBUF SQL function must be a constant " +
@@ -2169,11 +2169,11 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
|SELECT from_protobuf(protobuf_data, 'SimpleMessageJavaTypes', 42, map())
|FROM ($toProtobufSql)
|""".stripMargin)),
- errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
+ condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
parameters = Map(
"sqlExpr" -> "\"from_protobuf(protobuf_data, SimpleMessageJavaTypes, 42, map())\"",
"msg" -> ("The third argument of the FROM_PROTOBUF SQL function must be a constant " +
- "string representing the Protobuf descriptor file path"),
+ "string or binary data representing the Protobuf descriptor file path"),
"hint" -> ""),
queryContext = Array(ExpectedContext(
fragment = "from_protobuf(protobuf_data, 'SimpleMessageJavaTypes', 42, map())",
@@ -2188,7 +2188,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot
| from_protobuf(protobuf_data, 'SimpleMessageJavaTypes', '$testFileDescFile', 42)
|FROM ($toProtobufSql)
|""".stripMargin)),
- errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
+ condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
parameters = Map(
"sqlExpr" ->
s"""\"from_protobuf(protobuf_data, SimpleMessageJavaTypes, $testFileDescFile, 42)\"""",
diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
index 03285c73f1ff1..2737bb9feb3ad 100644
--- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
+++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala
@@ -95,7 +95,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase {
protoFile,
Deserializer,
fieldMatch,
- errorClass = "CANNOT_CONVERT_PROTOBUF_MESSAGE_TYPE_TO_SQL_TYPE",
+ condition = "CANNOT_CONVERT_PROTOBUF_MESSAGE_TYPE_TO_SQL_TYPE",
params = Map(
"protobufType" -> "MissMatchTypeInRoot",
"toType" -> toSQLType(CATALYST_STRUCT)))
@@ -104,7 +104,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase {
protoFile,
Serializer,
fieldMatch,
- errorClass = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE",
+ condition = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE",
params = Map(
"protobufType" -> "MissMatchTypeInRoot",
"toType" -> toSQLType(CATALYST_STRUCT)))
@@ -122,7 +122,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase {
protoFile,
Serializer,
BY_NAME,
- errorClass = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE",
+ condition = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE",
params = Map(
"protobufType" -> "FieldMissingInProto",
"toType" -> toSQLType(CATALYST_STRUCT)))
@@ -132,7 +132,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase {
Serializer,
BY_NAME,
nonnullCatalyst,
- errorClass = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE",
+ condition = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE",
params = Map(
"protobufType" -> "FieldMissingInProto",
"toType" -> toSQLType(nonnullCatalyst)))
@@ -150,7 +150,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase {
Deserializer,
fieldMatch,
catalyst,
- errorClass = "CANNOT_CONVERT_PROTOBUF_MESSAGE_TYPE_TO_SQL_TYPE",
+ condition = "CANNOT_CONVERT_PROTOBUF_MESSAGE_TYPE_TO_SQL_TYPE",
params = Map(
"protobufType" -> "MissMatchTypeInDeepNested",
"toType" -> toSQLType(catalyst)))
@@ -160,7 +160,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase {
Serializer,
fieldMatch,
catalyst,
- errorClass = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE",
+ condition = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE",
params = Map(
"protobufType" -> "MissMatchTypeInDeepNested",
"toType" -> toSQLType(catalyst)))
@@ -177,7 +177,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase {
Serializer,
BY_NAME,
catalystSchema = foobarSQLType,
- errorClass = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE",
+ condition = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE",
params = Map(
"protobufType" -> "FoobarWithRequiredFieldBar",
"toType" -> toSQLType(foobarSQLType)))
@@ -199,7 +199,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase {
Serializer,
BY_NAME,
catalystSchema = nestedFoobarSQLType,
- errorClass = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE",
+ condition = "UNABLE_TO_CONVERT_TO_PROTOBUF_MESSAGE_TYPE",
params = Map(
"protobufType" -> "NestedFoobarWithRequiredFieldBar",
"toType" -> toSQLType(nestedFoobarSQLType)))
@@ -222,7 +222,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase {
checkError(
exception = e1,
- errorClass = "CANNOT_PARSE_PROTOBUF_DESCRIPTOR")
+ condition = "CANNOT_PARSE_PROTOBUF_DESCRIPTOR")
val basicMessageDescWithoutImports = descriptorSetWithoutImports(
ProtobufUtils.readDescriptorFileContent(
@@ -240,7 +240,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase {
checkError(
exception = e2,
- errorClass = "PROTOBUF_DEPENDENCY_NOT_FOUND",
+ condition = "PROTOBUF_DEPENDENCY_NOT_FOUND",
parameters = Map("dependencyName" -> "nestedenum.proto"))
}
@@ -254,7 +254,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase {
serdeFactory: SerdeFactory[_],
fieldMatchType: MatchType,
catalystSchema: StructType = CATALYST_STRUCT,
- errorClass: String,
+ condition: String,
params: Map[String, String]): Unit = {
val e = intercept[AnalysisException] {
@@ -274,7 +274,7 @@ class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestBase {
assert(e.getMessage === expectMsg)
checkError(
exception = e,
- errorClass = errorClass,
+ condition = condition,
parameters = params)
}
diff --git a/core/pom.xml b/core/pom.xml
index 0a339e11a5d20..19f58940ed942 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -32,7 +32,6 @@
core
- **/OpenTelemetry*.scala
@@ -122,19 +121,14 @@
io.jsonwebtokenjjwt-api
- 0.12.6io.jsonwebtokenjjwt-impl
- 0.12.6
- testio.jsonwebtokenjjwt-jackson
- 0.12.6
- test
@@ -627,34 +613,10 @@
- opentelemetry
+ jjwt
-
+ compile
-
-
- io.opentelemetry
- opentelemetry-exporter-otlp
- 1.41.0
-
-
- io.opentelemetry
- opentelemetry-sdk-extension-autoconfigure-spi
-
-
-
-
- io.opentelemetry
- opentelemetry-sdk
- 1.41.0
-
-
- com.squareup.okhttp3
- okhttp
- 3.12.12
- test
-
- sparkr
diff --git a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java
index 4e251a1c2901b..412d612c7f1d5 100644
--- a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java
+++ b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java
@@ -17,6 +17,7 @@
package org.apache.spark.io;
import org.apache.spark.storage.StorageUtils;
+import org.apache.spark.unsafe.Platform;
import java.io.File;
import java.io.IOException;
@@ -47,7 +48,7 @@ public final class NioBufferedFileInputStream extends InputStream {
private final FileChannel fileChannel;
public NioBufferedFileInputStream(File file, int bufferSizeInBytes) throws IOException {
- byteBuffer = ByteBuffer.allocateDirect(bufferSizeInBytes);
+ byteBuffer = Platform.allocateDirectBuffer(bufferSizeInBytes);
fileChannel = FileChannel.open(file.toPath(), StandardOpenOption.READ);
byteBuffer.flip();
this.cleanable = CLEANER.register(this, new ResourceCleaner(fileChannel, byteBuffer));
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 485f0abcd25ee..042179d86c31a 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -27,6 +27,7 @@ import scala.collection.Map
import scala.collection.concurrent.{Map => ScalaConcurrentMap}
import scala.collection.immutable
import scala.collection.mutable.HashMap
+import scala.concurrent.{Future, Promise}
import scala.jdk.CollectionConverters._
import scala.reflect.{classTag, ClassTag}
import scala.util.control.NonFatal
@@ -909,10 +910,20 @@ class SparkContext(config: SparkConf) extends Logging {
*
* @since 3.5.0
*/
- def addJobTag(tag: String): Unit = {
- SparkContext.throwIfInvalidTag(tag)
+ def addJobTag(tag: String): Unit = addJobTags(Set(tag))
+
+ /**
+ * Add multiple tags to be assigned to all the jobs started by this thread.
+ * See [[addJobTag]] for more details.
+ *
+ * @param tags The tags to be added. Cannot contain ',' (comma) character.
+ *
+ * @since 4.0.0
+ */
+ def addJobTags(tags: Set[String]): Unit = {
+ tags.foreach(SparkContext.throwIfInvalidTag)
val existingTags = getJobTags()
- val newTags = (existingTags + tag).mkString(SparkContext.SPARK_JOB_TAGS_SEP)
+ val newTags = (existingTags ++ tags).mkString(SparkContext.SPARK_JOB_TAGS_SEP)
setLocalProperty(SparkContext.SPARK_JOB_TAGS, newTags)
}
@@ -924,10 +935,20 @@ class SparkContext(config: SparkConf) extends Logging {
*
* @since 3.5.0
*/
- def removeJobTag(tag: String): Unit = {
- SparkContext.throwIfInvalidTag(tag)
+ def removeJobTag(tag: String): Unit = removeJobTags(Set(tag))
+
+ /**
+ * Remove multiple tags from the set of tags to be assigned to all the jobs started by this thread.
+ * See [[removeJobTag]] for more details.
+ *
+ * @param tags The tags to be removed. Cannot contain ',' (comma) character.
+ *
+ * @since 4.0.0
+ */
+ def removeJobTags(tags: Set[String]): Unit = {
+ tags.foreach(SparkContext.throwIfInvalidTag)
val existingTags = getJobTags()
- val newTags = (existingTags - tag).mkString(SparkContext.SPARK_JOB_TAGS_SEP)
+ val newTags = (existingTags -- tags).mkString(SparkContext.SPARK_JOB_TAGS_SEP)
if (newTags.isEmpty) {
clearJobTags()
} else {
@@ -2684,6 +2705,25 @@ class SparkContext(config: SparkConf) extends Logging {
dagScheduler.cancelJobGroup(groupId, cancelFutureJobs = true, None)
}
+ /**
+ * Cancel active jobs that have the specified tag. See `org.apache.spark.SparkContext.addJobTag`.
+ *
+ * @param tag The tag to be cancelled. Cannot contain ',' (comma) character.
+ * @param reason reason for cancellation.
+ * @return A future with [[ActiveJob]]s, allowing extraction of information such as Job ID and
+ * tags.
+ */
+ private[spark] def cancelJobsWithTagWithFuture(
+ tag: String,
+ reason: String): Future[Seq[ActiveJob]] = {
+ SparkContext.throwIfInvalidTag(tag)
+ assertNotStopped()
+
+ val cancelledJobs = Promise[Seq[ActiveJob]]()
+ dagScheduler.cancelJobsWithTag(tag, Some(reason), Some(cancelledJobs))
+ cancelledJobs.future
+ }
+
/**
* Cancel active jobs that have the specified tag. See `org.apache.spark.SparkContext.addJobTag`.
*
@@ -2695,7 +2735,7 @@ class SparkContext(config: SparkConf) extends Logging {
def cancelJobsWithTag(tag: String, reason: String): Unit = {
SparkContext.throwIfInvalidTag(tag)
assertNotStopped()
- dagScheduler.cancelJobsWithTag(tag, Option(reason))
+ dagScheduler.cancelJobsWithTag(tag, Option(reason), cancelledJobs = None)
}
/**
@@ -2708,7 +2748,7 @@ class SparkContext(config: SparkConf) extends Logging {
def cancelJobsWithTag(tag: String): Unit = {
SparkContext.throwIfInvalidTag(tag)
assertNotStopped()
- dagScheduler.cancelJobsWithTag(tag, None)
+ dagScheduler.cancelJobsWithTag(tag, reason = None, cancelledJobs = None)
}
/** Cancel all jobs that have been scheduled or are running. */
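
A usage sketch of the tag APIs touched above, assuming an existing `SparkContext` named `sc`; the tag names and job body are illustrative:

    sc.addJobTags(Set("etl", "nightly"))
    try {
      sc.parallelize(1 to 1000).map(_ * 2).count()   // runs with both tags attached
    } finally {
      sc.removeJobTags(Set("etl", "nightly"))
    }
    // From another thread, every job still carrying the tag can be interrupted:
    sc.cancelJobsWithTag("nightly", reason = "nightly window closed")
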
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index f8b7cdcf7a8b0..6b7fc1b0804b7 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -73,7 +73,7 @@ class SparkEnv (
// We initialize the ShuffleManager later in SparkContext and Executor to allow
// user jars to define custom ShuffleManagers.
- private var _shuffleManager: ShuffleManager = _
+ @volatile private var _shuffleManager: ShuffleManager = _
def shuffleManager: ShuffleManager = _shuffleManager
diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
index 667bf8bbc9754..e9507fa6bee48 100644
--- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
@@ -25,6 +25,9 @@ import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters._
import scala.util.Try
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods.{compact, render}
+
import org.apache.spark.{SparkConf, SparkUserAppException}
import org.apache.spark.api.python.{Py4JServer, PythonUtils}
import org.apache.spark.internal.config._
@@ -50,18 +53,21 @@ object PythonRunner {
val formattedPythonFile = formatPath(pythonFile)
val formattedPyFiles = resolvePyFiles(formatPaths(pyFiles))
- val gatewayServer = new Py4JServer(sparkConf)
+ var gatewayServer: Option[Py4JServer] = None
+ if (sparkConf.getOption("spark.remote").isEmpty) {
+ gatewayServer = Some(new Py4JServer(sparkConf))
- val thread = new Thread(() => Utils.logUncaughtExceptions { gatewayServer.start() })
- thread.setName("py4j-gateway-init")
- thread.setDaemon(true)
- thread.start()
+ val thread = new Thread(() => Utils.logUncaughtExceptions { gatewayServer.get.start() })
+ thread.setName("py4j-gateway-init")
+ thread.setDaemon(true)
+ thread.start()
- // Wait until the gateway server has started, so that we know which port is it bound to.
- // `gatewayServer.start()` will start a new thread and run the server code there, after
- // initializing the socket, so the thread started above will end as soon as the server is
- // ready to serve connections.
- thread.join()
+ // Wait until the gateway server has started, so that we know which port it is bound to.
+ // `gatewayServer.start()` will start a new thread and run the server code there, after
+ // initializing the socket, so the thread started above will end as soon as the server is
+ // ready to serve connections.
+ thread.join()
+ }
// Build up a PYTHONPATH that includes the Spark assembly (where this class is), the
// python directories in SPARK_HOME (if set), and any files in the pyFiles argument
@@ -74,12 +80,22 @@ object PythonRunner {
// Launch Python process
val builder = new ProcessBuilder((Seq(pythonExec, formattedPythonFile) ++ otherArgs).asJava)
val env = builder.environment()
+ if (sparkConf.getOption("spark.remote").nonEmpty) {
+ // For non-local remote, pass configurations to environment variables so
+ // Spark Connect client sets them. For local remotes, they will be set
+ // via Py4J.
+ val grouped = sparkConf.getAll.toMap.grouped(10).toSeq
+ env.put("PYSPARK_REMOTE_INIT_CONF_LEN", grouped.length.toString)
+ grouped.zipWithIndex.foreach { case (group, idx) =>
+ env.put(s"PYSPARK_REMOTE_INIT_CONF_$idx", compact(render(group)))
+ }
+ }
sparkConf.getOption("spark.remote").foreach(url => env.put("SPARK_REMOTE", url))
env.put("PYTHONPATH", pythonPath)
// This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string
- env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort)
- env.put("PYSPARK_GATEWAY_SECRET", gatewayServer.secret)
+ gatewayServer.foreach(s => env.put("PYSPARK_GATEWAY_PORT", s.getListeningPort.toString))
+ gatewayServer.foreach(s => env.put("PYSPARK_GATEWAY_SECRET", s.secret))
// pass conf spark.pyspark.python to python process, the only way to pass info to
// python process is through environment variable.
sparkConf.get(PYSPARK_PYTHON).foreach(env.put("PYSPARK_PYTHON", _))
@@ -103,7 +119,7 @@ object PythonRunner {
throw new SparkUserAppException(exitCode)
}
} finally {
- gatewayServer.shutdown()
+ gatewayServer.foreach(_.shutdown())
}
}
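
The hunk above only shows the producing side. As a rough illustration of the contract (not the actual PySpark client code, which is outside this diff), a Spark Connect Python client could reassemble the configuration like this:

```python
import json
import os


def read_remote_init_conf() -> dict:
    """Rebuild the SparkConf entries passed through PYSPARK_REMOTE_INIT_CONF_* variables."""
    conf = {}
    # PYSPARK_REMOTE_INIT_CONF_LEN holds the number of groups; each indexed variable
    # holds a JSON object with up to 10 configuration entries.
    num_groups = int(os.environ.get("PYSPARK_REMOTE_INIT_CONF_LEN", "0"))
    for i in range(num_groups):
        conf.update(json.loads(os.environ[f"PYSPARK_REMOTE_INIT_CONF_{i}"]))
    return conf
```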
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index 32dd2f81bbc82..2c9ddff348056 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -43,7 +43,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
extends SparkSubmitArgumentsParser with Logging {
var maybeMaster: Option[String] = None
// Global defaults. These should be kept to a minimum to avoid confusing behavior.
- def master: String = maybeMaster.getOrElse("local[*]")
+ def master: String =
+ maybeMaster.getOrElse(System.getProperty("spark.test.master", "local[*]"))
var maybeRemote: Option[String] = None
var deployMode: String = null
var executorMemory: String = null
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 0e19143411e96..c5646d2956aeb 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -271,6 +271,18 @@ package object config {
.toSequence
.createWithDefault(GarbageCollectionMetrics.OLD_GENERATION_BUILTIN_GARBAGE_COLLECTORS)
+ private[spark] val EVENT_LOG_INCLUDE_TASK_METRICS_ACCUMULATORS =
+ ConfigBuilder("spark.eventLog.includeTaskMetricsAccumulators")
+ .doc("Whether to include TaskMetrics' underlying accumulator values in the event log (as " +
+ "part of the Task/Stage/Job metrics' 'Accumulables' fields). This configuration defaults " +
+ "to false because the TaskMetrics values are already logged in the 'Task Metrics' " +
+ "fields (so the accumulator updates are redundant). This flag exists only as a " +
+ "backwards-compatibility escape hatch for applications that might rely on the old " +
+ "behavior. See SPARK-42204 for details.")
+ .version("4.0.0")
+ .booleanConf
+ .createWithDefault(false)
+
private[spark] val EVENT_LOG_OVERWRITE =
ConfigBuilder("spark.eventLog.overwrite")
.version("1.0.0")
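
For example, an application that still wants the pre-SPARK-42204 behavior would opt back in explicitly (a minimal sketch; only the new key comes from this change):

```scala
import org.apache.spark.SparkConf

// Illustrative: re-enable logging of TaskMetrics accumulator updates in the event log.
val conf = new SparkConf()
  .set("spark.eventLog.enabled", "true")
  .set("spark.eventLog.includeTaskMetricsAccumulators", "true")
```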
@@ -1374,7 +1386,6 @@ package object config {
private[spark] val SHUFFLE_ACCURATE_BLOCK_SKEWED_FACTOR =
ConfigBuilder("spark.shuffle.accurateBlockSkewedFactor")
- .internal()
.doc("A shuffle block is considered as skewed and will be accurately recorded in " +
"HighlyCompressedMapStatus if its size is larger than this factor multiplying " +
"the median shuffle block size or SHUFFLE_ACCURATE_BLOCK_THRESHOLD. It is " +
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/opentelemetry/OpenTelemetryPushReporter.scala b/core/src/main/scala/org/apache/spark/metrics/sink/opentelemetry/OpenTelemetryPushReporter.scala
deleted file mode 100644
index bab7023ecdf11..0000000000000
--- a/core/src/main/scala/org/apache/spark/metrics/sink/opentelemetry/OpenTelemetryPushReporter.scala
+++ /dev/null
@@ -1,355 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.metrics.sink.opentelemetry
-
-import java.nio.file.{Files, Paths}
-import java.util.{Locale, SortedMap}
-import java.util.concurrent.TimeUnit
-
-import com.codahale.metrics._
-import io.opentelemetry.api.common.Attributes
-import io.opentelemetry.api.metrics.{DoubleGauge, DoubleHistogram, LongCounter}
-import io.opentelemetry.exporter.otlp.metrics.OtlpGrpcMetricExporter
-import io.opentelemetry.sdk.OpenTelemetrySdk
-import io.opentelemetry.sdk.metrics.SdkMeterProvider
-import io.opentelemetry.sdk.metrics.export.PeriodicMetricReader
-import io.opentelemetry.sdk.resources.Resource
-
-private[spark] class OpenTelemetryPushReporter(
- registry: MetricRegistry,
- pollInterval: Int = 10,
- pollUnit: TimeUnit = TimeUnit.SECONDS,
- endpoint: String = "http://localhost:4317",
- headersMap: Map[String, String] = Map(),
- attributesMap: Map[String, String] = Map(),
- trustedCertificatesPath: String,
- privateKeyPemPath: String,
- certificatePemPath: String)
- extends ScheduledReporter (
- registry,
- "opentelemetry-push-reporter",
- MetricFilter.ALL,
- TimeUnit.SECONDS,
- TimeUnit.MILLISECONDS)
- with MetricRegistryListener {
-
- val FIFTEEN_MINUTE_RATE = "_fifteen_minute_rate"
- val FIVE_MINUTE_RATE = "_five_minute_rate"
- val ONE_MINUTE_RATE = "_one_minute_rate"
- val MEAN_RATE = "_mean_rate"
- val METER = "_meter"
- val TIMER = "_timer"
- val COUNT = "_count"
- val MAX = "_max"
- val MIN = "_min"
- val MEAN = "_mean"
- val MEDIAN = "_50_percentile"
- val SEVENTY_FIFTH_PERCENTILE = "_75_percentile"
- val NINETY_FIFTH_PERCENTILE = "_95_percentile"
- val NINETY_EIGHTH_PERCENTILE = "_98_percentile"
- val NINETY_NINTH_PERCENTILE = "_99_percentile"
- val NINE_HUNDRED_NINETY_NINTH_PERCENTILE = "_999_percentile"
- val STD_DEV = "_std_dev"
-
- val otlpGrpcMetricExporterBuilder = OtlpGrpcMetricExporter.builder()
-
- for ((key, value) <- headersMap) {
- otlpGrpcMetricExporterBuilder.addHeader(key, value)
- }
-
- if (trustedCertificatesPath != null) {
- otlpGrpcMetricExporterBuilder
- .setTrustedCertificates(Files.readAllBytes(Paths.get(trustedCertificatesPath)))
- }
-
- if (privateKeyPemPath != null && certificatePemPath != null) {
- otlpGrpcMetricExporterBuilder
- .setClientTls(
- Files.readAllBytes(Paths.get(privateKeyPemPath)),
- Files.readAllBytes(Paths.get(certificatePemPath)))
- }
-
- otlpGrpcMetricExporterBuilder.setEndpoint(endpoint)
-
- val arrtributesBuilder = Attributes.builder()
- for ((key, value) <- attributesMap) {
- arrtributesBuilder.put(key, value)
- }
-
- val resource = Resource
- .getDefault()
- .merge(Resource.create(arrtributesBuilder.build()));
-
- val metricReader = PeriodicMetricReader
- .builder(otlpGrpcMetricExporterBuilder.build())
- .setInterval(pollInterval, pollUnit)
- .build()
-
- val sdkMeterProvider: SdkMeterProvider = SdkMeterProvider
- .builder()
- .registerMetricReader(metricReader)
- .setResource(resource)
- .build()
-
- val openTelemetryCounters = collection.mutable.Map[String, LongCounter]()
- val openTelemetryHistograms = collection.mutable.Map[String, DoubleHistogram]()
- val openTelemetryGauges = collection.mutable.Map[String, DoubleGauge]()
- val codahaleCounters = collection.mutable.Map[String, Counter]()
- val openTelemetry = OpenTelemetrySdk
- .builder()
- .setMeterProvider(sdkMeterProvider)
- .build();
- val openTelemetryMeter = openTelemetry.getMeter("apache-spark")
-
- override def report(
- gauges: SortedMap[String, Gauge[_]],
- counters: SortedMap[String, Counter],
- histograms: SortedMap[String, Histogram],
- meters: SortedMap[String, Meter],
- timers: SortedMap[String, Timer]): Unit = {
- counters.forEach(this.reportCounter)
- gauges.forEach(this.reportGauges)
- histograms.forEach(this.reportHistograms)
- meters.forEach(this.reportMeters)
- timers.forEach(this.reportTimers)
- sdkMeterProvider.forceFlush
- }
-
- override def onGaugeAdded(name: String, gauge: Gauge[_]): Unit = {
- val metricName = normalizeMetricName(name)
- generateGauge(metricName)
- }
-
- override def onGaugeRemoved(name: String): Unit = {
- val metricName = normalizeMetricName(name)
- openTelemetryGauges.remove(metricName)
- }
-
- override def onCounterAdded(name: String, counter: Counter): Unit = {
- val metricName = normalizeMetricName(name)
- val addedOpenTelemetryCounter =
- openTelemetryMeter.counterBuilder(normalizeMetricName(metricName)).build
- openTelemetryCounters.put(metricName, addedOpenTelemetryCounter)
- codahaleCounters.put(metricName, registry.counter(metricName))
- }
-
- override def onCounterRemoved(name: String): Unit = {
- val metricName = normalizeMetricName(name)
- openTelemetryCounters.remove(metricName)
- codahaleCounters.remove(metricName)
- }
-
- override def onHistogramAdded(name: String, histogram: Histogram): Unit = {
- val metricName = normalizeMetricName(name)
- generateHistogramGroup(metricName)
- }
-
- override def onHistogramRemoved(name: String): Unit = {
- val metricName = normalizeMetricName(name)
- cleanHistogramGroup(metricName)
- }
-
- override def onMeterAdded(name: String, meter: Meter): Unit = {
- val metricName = normalizeMetricName(name) + METER
- generateGauge(metricName + COUNT)
- generateGauge(metricName + MEAN_RATE)
- generateGauge(metricName + FIFTEEN_MINUTE_RATE)
- generateGauge(metricName + FIVE_MINUTE_RATE)
- generateGauge(metricName + ONE_MINUTE_RATE)
- }
-
- override def onMeterRemoved(name: String): Unit = {
- val metricName = normalizeMetricName(name) + METER
- openTelemetryGauges.remove(metricName + COUNT)
- openTelemetryGauges.remove(metricName + MEAN_RATE)
- openTelemetryGauges.remove(metricName + ONE_MINUTE_RATE)
- openTelemetryGauges.remove(metricName + FIVE_MINUTE_RATE)
- openTelemetryGauges.remove(metricName + FIFTEEN_MINUTE_RATE)
- }
-
- override def onTimerAdded(name: String, timer: Timer): Unit = {
- val metricName = normalizeMetricName(name) + TIMER
- generateHistogramGroup(metricName)
- generateAdditionalHistogramGroupForTimers(metricName)
- }
-
- override def onTimerRemoved(name: String): Unit = {
- val metricName = normalizeMetricName(name) + TIMER
- cleanHistogramGroup(name)
- cleanAdditionalHistogramGroupTimers(metricName)
- }
-
- override def stop(): Unit = {
- super.stop()
- sdkMeterProvider.close()
- }
-
- private def normalizeMetricName(name: String): String = {
- name.toLowerCase(Locale.ROOT).replaceAll("[^a-z0-9]", "_")
- }
-
- private def generateHistogram(metricName: String): Unit = {
- val openTelemetryHistogram =
- openTelemetryMeter.histogramBuilder(metricName).build
- openTelemetryHistograms.put(metricName, openTelemetryHistogram)
- }
-
- private def generateHistogramGroup(metricName: String): Unit = {
- generateHistogram(metricName + COUNT)
- generateHistogram(metricName + MAX)
- generateHistogram(metricName + MIN)
- generateHistogram(metricName + MEAN)
- generateHistogram(metricName + MEDIAN)
- generateHistogram(metricName + STD_DEV)
- generateHistogram(metricName + SEVENTY_FIFTH_PERCENTILE)
- generateHistogram(metricName + NINETY_FIFTH_PERCENTILE)
- generateHistogram(metricName + NINETY_EIGHTH_PERCENTILE)
- generateHistogram(metricName + NINETY_NINTH_PERCENTILE)
- generateHistogram(metricName + NINE_HUNDRED_NINETY_NINTH_PERCENTILE)
- }
-
- private def generateAdditionalHistogramGroupForTimers(metricName: String): Unit = {
- generateHistogram(metricName + FIFTEEN_MINUTE_RATE)
- generateHistogram(metricName + FIVE_MINUTE_RATE)
- generateHistogram(metricName + ONE_MINUTE_RATE)
- generateHistogram(metricName + MEAN_RATE)
- }
-
- private def cleanHistogramGroup(metricName: String): Unit = {
- openTelemetryHistograms.remove(metricName + COUNT)
- openTelemetryHistograms.remove(metricName + MAX)
- openTelemetryHistograms.remove(metricName + MIN)
- openTelemetryHistograms.remove(metricName + MEAN)
- openTelemetryHistograms.remove(metricName + MEDIAN)
- openTelemetryHistograms.remove(metricName + STD_DEV)
- openTelemetryHistograms.remove(metricName + SEVENTY_FIFTH_PERCENTILE)
- openTelemetryHistograms.remove(metricName + NINETY_FIFTH_PERCENTILE)
- openTelemetryHistograms.remove(metricName + NINETY_EIGHTH_PERCENTILE)
- openTelemetryHistograms.remove(metricName + NINETY_NINTH_PERCENTILE)
- openTelemetryHistograms.remove(metricName + NINE_HUNDRED_NINETY_NINTH_PERCENTILE)
- }
-
- private def cleanAdditionalHistogramGroupTimers(metricName: String): Unit = {
- openTelemetryHistograms.remove(metricName + FIFTEEN_MINUTE_RATE)
- openTelemetryHistograms.remove(metricName + FIVE_MINUTE_RATE)
- openTelemetryHistograms.remove(metricName + ONE_MINUTE_RATE)
- openTelemetryHistograms.remove(metricName + MEAN_RATE)
- }
-
- private def generateGauge(metricName: String): Unit = {
- val addedOpenTelemetryGauge =
- openTelemetryMeter.gaugeBuilder(normalizeMetricName(metricName)).build
- openTelemetryGauges.put(metricName, addedOpenTelemetryGauge)
- }
-
- private def reportCounter(name: String, counter: Counter): Unit = {
- val metricName = normalizeMetricName(name)
- val openTelemetryCounter = openTelemetryCounters(metricName)
- val codahaleCounter = codahaleCounters(metricName)
- val diff = counter.getCount - codahaleCounter.getCount
- openTelemetryCounter.add(diff)
- codahaleCounter.inc(diff)
- }
-
- private def reportGauges(name: String, gauge: Gauge[_]): Unit = {
- val metricName = normalizeMetricName(name)
- gauge.getValue match {
- case d: Double =>
- openTelemetryGauges(metricName).set(d.doubleValue)
- case d: Long =>
- openTelemetryGauges(metricName).set(d.doubleValue)
- case d: Int =>
- openTelemetryGauges(metricName).set(d.doubleValue)
- case _ => ()
- }
- }
-
- private def reportHistograms(name: String, histogram: Histogram): Unit = {
- val metricName = normalizeMetricName(name)
- reportHistogramGroup(metricName, histogram)
- }
-
- private def reportMeters(name: String, meter: Meter): Unit = {
- val metricName = normalizeMetricName(name) + METER
- val openTelemetryGaugeCount = openTelemetryGauges(metricName + COUNT)
- openTelemetryGaugeCount.set(meter.getCount.toDouble)
- val openTelemetryGauge0neMinuteRate = openTelemetryGauges(metricName + ONE_MINUTE_RATE)
- openTelemetryGauge0neMinuteRate.set(meter.getOneMinuteRate)
- val openTelemetryGaugeFiveMinuteRate = openTelemetryGauges(metricName + FIVE_MINUTE_RATE)
- openTelemetryGaugeFiveMinuteRate.set(meter.getFiveMinuteRate)
- val openTelemetryGaugeFifteenMinuteRate = openTelemetryGauges(
- metricName + FIFTEEN_MINUTE_RATE)
- openTelemetryGaugeFifteenMinuteRate.set(meter.getFifteenMinuteRate)
- val openTelemetryGaugeMeanRate = openTelemetryGauges(metricName + MEAN_RATE)
- openTelemetryGaugeMeanRate.set(meter.getMeanRate)
- }
-
- private def reportTimers(name: String, timer: Timer): Unit = {
- val metricName = normalizeMetricName(name) + TIMER
- val openTelemetryHistogramMax = openTelemetryHistograms(metricName + MAX)
- openTelemetryHistogramMax.record(timer.getCount.toDouble)
- val openTelemetryHistogram0neMinuteRate = openTelemetryHistograms(
- metricName + ONE_MINUTE_RATE)
- openTelemetryHistogram0neMinuteRate.record(timer.getOneMinuteRate)
- val openTelemetryHistogramFiveMinuteRate = openTelemetryHistograms(
- metricName + FIVE_MINUTE_RATE)
- openTelemetryHistogramFiveMinuteRate.record(timer.getFiveMinuteRate)
- val openTelemetryHistogramFifteenMinuteRate = openTelemetryHistograms(
- metricName + FIFTEEN_MINUTE_RATE)
- openTelemetryHistogramFifteenMinuteRate.record(timer.getFifteenMinuteRate)
- val openTelemetryHistogramMeanRate = openTelemetryHistograms(metricName + MEAN_RATE)
- openTelemetryHistogramMeanRate.record(timer.getMeanRate)
- val snapshot = timer.getSnapshot
- reportHistogramGroup(metricName, snapshot)
- }
-
- private def reportHistogramGroup(metricName: String, histogram: Histogram): Unit = {
- val openTelemetryHistogramCount = openTelemetryHistograms(metricName + COUNT)
- openTelemetryHistogramCount.record(histogram.getCount.toDouble)
- val snapshot = histogram.getSnapshot
- reportHistogramGroup(metricName, snapshot)
- }
-
- private def reportHistogramGroup(metricName: String, snapshot: Snapshot): Unit = {
- val openTelemetryHistogramMax = openTelemetryHistograms(metricName + MAX)
- openTelemetryHistogramMax.record(snapshot.getMax.toDouble)
- val openTelemetryHistogramMin = openTelemetryHistograms(metricName + MIN)
- openTelemetryHistogramMin.record(snapshot.getMin.toDouble)
- val openTelemetryHistogramMean = openTelemetryHistograms(metricName + MEAN)
- openTelemetryHistogramMean.record(snapshot.getMean)
- val openTelemetryHistogramMedian = openTelemetryHistograms(metricName + MEDIAN)
- openTelemetryHistogramMedian.record(snapshot.getMedian)
- val openTelemetryHistogramStdDev = openTelemetryHistograms(metricName + STD_DEV)
- openTelemetryHistogramStdDev.record(snapshot.getStdDev)
- val openTelemetryHistogram75Percentile = openTelemetryHistograms(
- metricName + SEVENTY_FIFTH_PERCENTILE)
- openTelemetryHistogram75Percentile.record(snapshot.get75thPercentile)
- val openTelemetryHistogram95Percentile = openTelemetryHistograms(
- metricName + NINETY_FIFTH_PERCENTILE)
- openTelemetryHistogram95Percentile.record(snapshot.get95thPercentile)
- val openTelemetryHistogram98Percentile = openTelemetryHistograms(
- metricName + NINETY_EIGHTH_PERCENTILE)
- openTelemetryHistogram98Percentile.record(snapshot.get98thPercentile)
- val openTelemetryHistogram99Percentile = openTelemetryHistograms(
- metricName + NINETY_NINTH_PERCENTILE)
- openTelemetryHistogram99Percentile.record(snapshot.get99thPercentile)
- val openTelemetryHistogram999Percentile = openTelemetryHistograms(
- metricName + NINE_HUNDRED_NINETY_NINTH_PERCENTILE)
- openTelemetryHistogram999Percentile.record(snapshot.get999thPercentile)
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/opentelemetry/OpenTelemetryPushSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/opentelemetry/OpenTelemetryPushSink.scala
deleted file mode 100644
index 23d047f585efc..0000000000000
--- a/core/src/main/scala/org/apache/spark/metrics/sink/opentelemetry/OpenTelemetryPushSink.scala
+++ /dev/null
@@ -1,109 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.metrics.sink.opentelemetry
-
-import java.util.{Locale, Properties}
-import java.util.concurrent.TimeUnit
-
-import com.codahale.metrics.MetricRegistry
-import org.apache.commons.lang3.StringUtils
-
-import org.apache.spark.internal.Logging
-import org.apache.spark.metrics.sink.Sink
-
-private[spark] object OpenTelemetryPushSink {
- private def fetchMapFromProperties(
- properties: Properties,
- keyPrefix: String): Map[String, String] = {
- val propertiesMap = scala.collection.mutable.Map[String, String]()
- val valueEnumeration = properties.propertyNames
- val dotCount = keyPrefix.count(_ == '.')
- while (valueEnumeration.hasMoreElements) {
- val key = valueEnumeration.nextElement.asInstanceOf[String]
- if (key.startsWith(keyPrefix)) {
- val dotIndex = StringUtils.ordinalIndexOf(key, ".", dotCount + 1)
- val mapKey = key.substring(dotIndex + 1)
- propertiesMap(mapKey) = properties.getProperty(key)
- }
- }
- propertiesMap.toMap
- }
-}
-
-private[spark] class OpenTelemetryPushSink(val property: Properties, val registry: MetricRegistry)
- extends Sink with Logging {
-
- val OPEN_TELEMETRY_KEY_PERIOD = "period"
- val OPEN_TELEMETRY_KEY_UNIT = "unit"
- val OPEN_TELEMETRY_DEFAULT_PERIOD = "10"
- val OPEN_TELEMETRY_DEFAULT_UNIT = "SECONDS"
- val OPEN_TELEMETRY_KEY_ENDPOINT = "endpoint"
- val GRPC_METRIC_EXPORTER_HEADER_KEY = "grpc.metric.exporter.header"
- val GRPC_METRIC_EXPORTER_ATTRIBUTES_KEY = "grpc.metric.exporter.attributes"
- val TRUSTED_CERTIFICATE_PATH = "trusted.certificate.path"
- val PRIVATE_KEY_PEM_PATH = "private.key.pem.path"
- val CERTIFICATE_PEM_PATH = "certificate.pem.path"
-
- val pollPeriod = property
- .getProperty(OPEN_TELEMETRY_KEY_PERIOD, OPEN_TELEMETRY_DEFAULT_PERIOD)
- .toInt
-
- val pollUnit = TimeUnit.valueOf(
- property
- .getProperty(OPEN_TELEMETRY_KEY_UNIT, OPEN_TELEMETRY_DEFAULT_UNIT)
- .toUpperCase(Locale.ROOT))
-
- val endpoint = property.getProperty(OPEN_TELEMETRY_KEY_ENDPOINT)
-
- val headersMap =
- OpenTelemetryPushSink.fetchMapFromProperties(property, GRPC_METRIC_EXPORTER_HEADER_KEY)
- val attributesMap =
- OpenTelemetryPushSink.fetchMapFromProperties(property, GRPC_METRIC_EXPORTER_ATTRIBUTES_KEY)
-
- val trustedCertificatesPath: String =
- property.getProperty(TRUSTED_CERTIFICATE_PATH)
-
- val privateKeyPemPath: String = property.getProperty(PRIVATE_KEY_PEM_PATH)
-
- val certificatePemPath: String = property.getProperty(CERTIFICATE_PEM_PATH)
-
- val reporter = new OpenTelemetryPushReporter(
- registry,
- pollInterval = pollPeriod,
- pollUnit,
- endpoint,
- headersMap,
- attributesMap,
- trustedCertificatesPath,
- privateKeyPemPath,
- certificatePemPath)
-
- registry.addListener(reporter)
-
- override def start(): Unit = {
- reporter.start(pollPeriod, pollUnit)
- }
-
- override def stop(): Unit = {
- reporter.stop()
- }
-
- override def report(): Unit = {
- reporter.report()
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 6c824e2fdeaed..2c89fe7885d08 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -27,6 +27,7 @@ import scala.annotation.tailrec
import scala.collection.Map
import scala.collection.mutable
import scala.collection.mutable.{HashMap, HashSet, ListBuffer}
+import scala.concurrent.Promise
import scala.concurrent.duration._
import scala.util.control.NonFatal
@@ -1116,11 +1117,18 @@ private[spark] class DAGScheduler(
/**
* Cancel all jobs with a given tag.
+ *
+ * @param tag The tag to be cancelled. Cannot contain ',' (comma) character.
+ * @param reason reason for cancellation.
+ * @param cancelledJobs a promise to be completed with the [[ActiveJob]]s being cancelled.
*/
- def cancelJobsWithTag(tag: String, reason: Option[String]): Unit = {
+ def cancelJobsWithTag(
+ tag: String,
+ reason: Option[String],
+ cancelledJobs: Option[Promise[Seq[ActiveJob]]]): Unit = {
SparkContext.throwIfInvalidTag(tag)
logInfo(log"Asked to cancel jobs with tag ${MDC(TAG, tag)}")
- eventProcessLoop.post(JobTagCancelled(tag, reason))
+ eventProcessLoop.post(JobTagCancelled(tag, reason, cancelledJobs))
}
/**
@@ -1234,17 +1242,22 @@ private[spark] class DAGScheduler(
jobIds.foreach(handleJobCancellation(_, Option(updatedReason)))
}
- private[scheduler] def handleJobTagCancelled(tag: String, reason: Option[String]): Unit = {
- // Cancel all jobs belonging that have this tag.
+ private[scheduler] def handleJobTagCancelled(
+ tag: String,
+ reason: Option[String],
+ cancelledJobs: Option[Promise[Seq[ActiveJob]]]): Unit = {
+ // Cancel all jobs that have the given tag.
// First finds all active jobs with this tag, and then kills the stages for them.
- val jobIds = activeJobs.filter { activeJob =>
+ val jobsToBeCancelled = activeJobs.filter { activeJob =>
Option(activeJob.properties).exists { properties =>
Option(properties.getProperty(SparkContext.SPARK_JOB_TAGS)).getOrElse("")
.split(SparkContext.SPARK_JOB_TAGS_SEP).filter(!_.isEmpty).toSet.contains(tag)
}
- }.map(_.jobId)
- val updatedReason = reason.getOrElse("part of cancelled job tag %s".format(tag))
- jobIds.foreach(handleJobCancellation(_, Option(updatedReason)))
+ }
+ val updatedReason =
+ reason.getOrElse("part of cancelled job tags %s".format(tag))
+ jobsToBeCancelled.map(_.jobId).foreach(handleJobCancellation(_, Option(updatedReason)))
+ cancelledJobs.map(_.success(jobsToBeCancelled.toSeq))
}
private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo): Unit = {
@@ -3113,8 +3126,8 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
case JobGroupCancelled(groupId, cancelFutureJobs, reason) =>
dagScheduler.handleJobGroupCancelled(groupId, cancelFutureJobs, reason)
- case JobTagCancelled(tag, reason) =>
- dagScheduler.handleJobTagCancelled(tag, reason)
+ case JobTagCancelled(tag, reason, cancelledJobs) =>
+ dagScheduler.handleJobTagCancelled(tag, reason, cancelledJobs)
case AllJobsCancelled =>
dagScheduler.doCancelAllJobs()
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index c9ad54d1fdc7e..8932d2ef323ba 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -19,6 +19,8 @@ package org.apache.spark.scheduler
import java.util.Properties
+import scala.concurrent.Promise
+
import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.{AccumulatorV2, CallSite}
@@ -71,7 +73,8 @@ private[scheduler] case class JobGroupCancelled(
private[scheduler] case class JobTagCancelled(
tagName: String,
- reason: Option[String]) extends DAGSchedulerEvent
+ reason: Option[String],
+ cancelledJobs: Option[Promise[Seq[ActiveJob]]]) extends DAGSchedulerEvent
private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent
diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
index efd8fecb974e8..1e46142fab255 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
@@ -31,7 +31,7 @@ import org.apache.spark.deploy.history.EventLogFileWriter
import org.apache.spark.executor.ExecutorMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
-import org.apache.spark.util.{JsonProtocol, Utils}
+import org.apache.spark.util.{JsonProtocol, JsonProtocolOptions, Utils}
/**
* A SparkListener that logs events to persistent storage.
@@ -74,6 +74,8 @@ private[spark] class EventLoggingListener(
private val liveStageExecutorMetrics =
mutable.HashMap.empty[(Int, Int), mutable.HashMap[String, ExecutorMetrics]]
+ private[this] val jsonProtocolOptions = new JsonProtocolOptions(sparkConf)
+
/**
* Creates the log file in the configured log directory.
*/
@@ -84,7 +86,7 @@ private[spark] class EventLoggingListener(
private def initEventLog(): Unit = {
val metadata = SparkListenerLogStart(SPARK_VERSION)
- val eventJson = JsonProtocol.sparkEventToJsonString(metadata)
+ val eventJson = JsonProtocol.sparkEventToJsonString(metadata, jsonProtocolOptions)
logWriter.writeEvent(eventJson, flushLogger = true)
if (testing && loggedEvents != null) {
loggedEvents += eventJson
@@ -93,7 +95,7 @@ private[spark] class EventLoggingListener(
/** Log the event as JSON. */
private def logEvent(event: SparkListenerEvent, flushLogger: Boolean = false): Unit = {
- val eventJson = JsonProtocol.sparkEventToJsonString(event)
+ val eventJson = JsonProtocol.sparkEventToJsonString(event, jsonProtocolOptions)
logWriter.writeEvent(eventJson, flushLogger)
if (testing) {
loggedEvents += eventJson
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index cc19b71bfc4d6..384f939a843bc 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -22,8 +22,6 @@ import javax.annotation.Nullable
import scala.collection.Map
-import com.fasterxml.jackson.annotation.JsonTypeInfo
-
import org.apache.spark.TaskEndReason
import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics}
@@ -31,13 +29,6 @@ import org.apache.spark.resource.ResourceProfile
import org.apache.spark.scheduler.cluster.ExecutorInfo
import org.apache.spark.storage.{BlockManagerId, BlockUpdatedInfo}
-@DeveloperApi
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "Event")
-trait SparkListenerEvent {
- /* Whether output this event to the event log */
- protected[spark] def logEvent: Boolean = true
-}
-
@DeveloperApi
case class SparkListenerStageSubmitted(stageInfo: StageInfo, properties: Properties = null)
extends SparkListenerEvent
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala
index 686ac1eb786e0..f29e8778da037 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerStorageEndpoint.scala
@@ -60,7 +60,13 @@ class BlockManagerStorageEndpoint(
if (mapOutputTracker != null) {
mapOutputTracker.unregisterShuffle(shuffleId)
}
- SparkEnv.get.shuffleManager.unregisterShuffle(shuffleId)
+ val shuffleManager = SparkEnv.get.shuffleManager
+ if (shuffleManager != null) {
+ shuffleManager.unregisterShuffle(shuffleId)
+ } else {
+ logDebug(log"Ignore remove shuffle ${MDC(SHUFFLE_ID, shuffleId)}")
+ true
+ }
}
case DecommissionBlockManager =>
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index 1498b224b0c92..3e57094b36a7e 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -35,6 +35,7 @@ import org.apache.spark.internal.LogKeys._
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.util.{AbstractFileRegion, JavaUtils}
import org.apache.spark.security.CryptoStreamUtils
+import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.util.Utils
import org.apache.spark.util.io.ChunkedByteBuffer
@@ -324,7 +325,7 @@ private class ReadableChannelFileRegion(source: ReadableByteChannel, blockSize:
private var _transferred = 0L
- private val buffer = ByteBuffer.allocateDirect(64 * 1024)
+ private val buffer = Platform.allocateDirectBuffer(64 * 1024)
buffer.flip()
override def count(): Long = blockSize
diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
index c3dc459b0f88e..fff6ec4f5b170 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
@@ -449,16 +449,24 @@ private[spark] object UIUtils extends Logging {
val startRatio = if (total == 0) 0.0 else (boundedStarted.toDouble / total) * 100
val startWidth = "width: %s%%".format(startRatio)
+ val killTaskReasonText = reasonToNumKilled.toSeq.sortBy(-_._2).map {
+ case (reason, count) => s" ($count killed: $reason)"
+ }.mkString
+ val progressTitle = s"$completed/$total" + {
+ if (started > 0) s" ($started running)" else ""
+ } + {
+ if (failed > 0) s" ($failed failed)" else ""
+ } + {
+ if (skipped > 0) s" ($skipped skipped)" else ""
+ } + killTaskReasonText
+
+    <td>Enables converting Protobuf Any fields to JSON. This option should be enabled carefully. JSON conversion and processing are inefficient. In addition, schema safety is also reduced, making downstream processing error-prone.</td>
+    <td>read</td>
+  </tr>
+  <tr>
+    <td>emit.default.values</td>
+    <td>false</td>
+    <td>Whether to render fields with zero values when deserializing Protobuf to a Spark struct. When a field is empty in the serialized Protobuf, this library will deserialize it as null by default; this option controls whether to render the type-specific zero values instead.</td>
+    <td>read</td>
+  </tr>
+  <tr>
+    <td>enums.as.ints</td>
+    <td>false</td>
+    <td>Whether to render enum fields as their integer values. When this option is set to false, an enum field will be mapped to StringType and the value is the name of the enum; when set to true, an enum field will be mapped to IntegerType and the value is its integer value.</td>
+    <td>read</td>
+  </tr>
+  <tr>
+    <td>upcast.unsigned.ints</td>
+    <td>false</td>
+    <td>Whether to upcast unsigned integers into a larger type. When this option is set to true, LongType is used for uint32 and Decimal(20, 0) is used for uint64, so their representation can contain large unsigned values without overflow.</td>
+    <td>read</td>
+  </tr>
+  <tr>
+    <td>unwrap.primitive.wrapper.types</td>
+    <td>false</td>
+    <td>Whether to unwrap the struct representation for well-known primitive wrapper types when deserializing. By default, the wrapper types for primitives (i.e. google.protobuf.Int32Value, google.protobuf.Int64Value, etc.) will get deserialized as structs.</td>
+    <td>read</td>
+  </tr>
+  <tr>
+    <td>retain.empty.message.types</td>
+    <td>false</td>
+    <td>Whether to retain fields of the empty proto message type in the schema. Since Spark doesn't allow writing empty StructType, the empty proto message type will be dropped by default. Setting this option to true will insert a dummy column (__dummy_field_in_empty_struct) into the empty proto message so that the empty message fields will be retained.</td>
+    <td>read</td>
+  </tr>
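
These options are passed to `from_protobuf` as a plain string map. A hedged Scala sketch (the message name, descriptor file path, and the binary `value` column are assumptions, not part of this change):

```scala
import scala.jdk.CollectionConverters._

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.protobuf.functions.from_protobuf

// Hypothetical message name and descriptor file; the option keys are the ones documented above.
val protobufOptions = Map(
  "enums.as.ints" -> "true",
  "emit.default.values" -> "true",
  "upcast.unsigned.ints" -> "true"
).asJava

def parseEvents(df: DataFrame): DataFrame =
  df.select(
    from_protobuf(col("value"), "org.example.Event", "/tmp/event.desc", protobufOptions).as("event"))
```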
diff --git a/docs/sql-distributed-sql-engine.md b/docs/sql-distributed-sql-engine.md
index 734723f8c6235..ae8fd9c7211bd 100644
--- a/docs/sql-distributed-sql-engine.md
+++ b/docs/sql-distributed-sql-engine.md
@@ -83,7 +83,7 @@ Use the following setting to enable HTTP mode as system property or in `hive-sit
To test, use beeline to connect to the JDBC/ODBC server in http mode with:
- beeline> !connect jdbc:hive2://<host>:<port>/<dbname>?hive.server2.transport.mode=http;hive.server2.thrift.http.path=<http_endpoint>
+ beeline> !connect jdbc:hive2://<host>:<port>/<dbname>;transportMode=http;httpPath=<http_endpoint>
If you closed a session and do CTAS, you must set `fs.%s.impl.disable.cache` to true in `hive-site.xml`.
See more details in [[SPARK-21067]](https://issues.apache.org/jira/browse/SPARK-21067).
@@ -94,4 +94,4 @@ To use the Spark SQL command line interface (CLI) from the shell:
./bin/spark-sql
-For details, please refer to [Spark SQL CLI](sql-distributed-sql-engine-spark-sql-cli.html)
\ No newline at end of file
+For details, please refer to [Spark SQL CLI](sql-distributed-sql-engine-spark-sql-cli.html)
diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md
index ad678c44657ed..0ecd45c2d8c56 100644
--- a/docs/sql-migration-guide.md
+++ b/docs/sql-migration-guide.md
@@ -60,6 +60,7 @@ license: |
- Since Spark 4.0, By default views tolerate column type changes in the query and compensate with casts. To restore the previous behavior, allowing up-casts only, set `spark.sql.legacy.viewSchemaCompensation` to `false`.
- Since Spark 4.0, Views allow control over how they react to underlying query changes. By default views tolerate column type changes in the query and compensate with casts. To disable this feature set `spark.sql.legacy.viewSchemaBindingMode` to `false`. This also removes the clause from `DESCRIBE EXTENDED` and `SHOW CREATE TABLE`.
- Since Spark 4.0, The Storage-Partitioned Join feature flag `spark.sql.sources.v2.bucketing.pushPartValues.enabled` is set to `true`. To restore the previous behavior, set `spark.sql.sources.v2.bucketing.pushPartValues.enabled` to `false`.
+- Since Spark 4.0, the `sentences` function uses `Locale(language)` instead of `Locale.US` when `language` parameter is not `NULL` and `country` parameter is `NULL`.
## Upgrading from Spark SQL 3.5.1 to 3.5.2
diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md
index f5e1ddfd3c576..12dff1e325c49 100644
--- a/docs/sql-ref-ansi-compliance.md
+++ b/docs/sql-ref-ansi-compliance.md
@@ -141,13 +141,13 @@ In the table above, all the `CAST`s with new syntax are marked as red
diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
--- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
@@ ... @@ List<String> buildClassPath(String appClassPath) throws IOException {
boolean prependClasses = !isEmpty(getenv("SPARK_PREPEND_CLASSES"));
boolean isTesting = "1".equals(getenv("SPARK_TESTING"));
+ boolean isTestingSql = "1".equals(getenv("SPARK_SQL_TESTING"));
+ String jarsDir = findJarsDir(getSparkHome(), getScalaVersion(), !isTesting && !isTestingSql);
if (prependClasses || isTesting) {
String scala = getScalaVersion();
List<String> projects = Arrays.asList(
@@ -176,6 +178,9 @@ List<String> buildClassPath(String appClassPath) throws IOException {
"NOTE: SPARK_PREPEND_CLASSES is set, placing locally compiled Spark classes ahead of " +
"assembly.");
}
+ boolean shouldPrePendSparkHive = isJarAvailable(jarsDir, "spark-hive_");
+ boolean shouldPrePendSparkHiveThriftServer =
+ shouldPrePendSparkHive && isJarAvailable(jarsDir, "spark-hive-thriftserver_");
for (String project : projects) {
// Do not use locally compiled class files for Spark server because it should use shaded
// dependencies.
@@ -185,6 +190,24 @@ List<String> buildClassPath(String appClassPath) throws IOException {
if (isRemote && "1".equals(getenv("SPARK_SCALA_SHELL")) && project.equals("sql/core")) {
continue;
}
+ // SPARK-49534: The assumption here is that if `spark-hive_xxx.jar` is not in the
+ // classpath, then the `-Phive` profile was not used during packaging, and therefore
+ // the Hive-related jars should also not be in the classpath. To avoid failure in
+ // loading the SPI in `DataSourceRegister` under `sql/hive`, no longer prepend `sql/hive`.
+ if (!shouldPrePendSparkHive && project.equals("sql/hive")) {
+ continue;
+ }
+ // SPARK-49534: Meanwhile, due to the strong dependency of `sql/hive-thriftserver`
+ // on `sql/hive`, the prepend for `sql/hive-thriftserver` will also be excluded
+ // if `spark-hive_xxx.jar` is not in the classpath. On the other hand, if
+ // `spark-hive-thriftserver_xxx.jar` is not in the classpath, then the
+ // `-Phive-thriftserver` profile was not used during package, and therefore,
+ // jars such as hive-cli and hive-beeline should also not be included in the classpath.
+ // To avoid the inelegant startup failures of tools such as spark-sql, in this scenario,
+ // `sql/hive-thriftserver` will no longer be prepended to the classpath.
+ if (!shouldPrePendSparkHiveThriftServer && project.equals("sql/hive-thriftserver")) {
+ continue;
+ }
addToClassPath(cp, String.format("%s/%s/target/scala-%s/classes", sparkHome, project,
scala));
}
@@ -205,8 +228,6 @@ List<String> buildClassPath(String appClassPath) throws IOException {
// Add Spark jars to the classpath. For the testing case, we rely on the test code to set and
// propagate the test classpath appropriately. For normal invocation, look for the jars
// directory under SPARK_HOME.
- boolean isTestingSql = "1".equals(getenv("SPARK_SQL_TESTING"));
- String jarsDir = findJarsDir(getSparkHome(), getScalaVersion(), !isTesting && !isTestingSql);
if (jarsDir != null) {
// Place slf4j-api-* jar first to be robust
for (File f: new File(jarsDir).listFiles()) {
@@ -214,7 +235,9 @@ List buildClassPath(String appClassPath) throws IOException {
addToClassPath(cp, f.toString());
}
}
- if (isRemote && "1".equals(getenv("SPARK_SCALA_SHELL"))) {
+ // If we're in 'spark.local.connect', it should create a Spark Classic SparkContext
+ // that launches the Spark Connect server.
+ if (isRemote && System.getenv("SPARK_LOCAL_CONNECT") == null) {
for (File f: new File(jarsDir).listFiles()) {
// Exclude Spark Classic SQL and Spark Connect server jars
// if we're in Spark Connect Shell. Also exclude Spark SQL API and
@@ -263,6 +286,24 @@ private void addToClassPath(Set<String> cp, String entries) {
}
}
+ /**
+ * Checks if a JAR file with a specific prefix is available in the given directory.
+ *
+ * @param jarsDir the directory to search for JAR files
+ * @param jarNamePrefix the prefix of the JAR file name to look for
+ * @return true if a JAR file with the specified prefix is found, false otherwise
+ */
+ private boolean isJarAvailable(String jarsDir, String jarNamePrefix) {
+ if (jarsDir != null) {
+ for (File f : new File(jarsDir).listFiles()) {
+ if (f.getName().startsWith(jarNamePrefix)) {
+ return true;
+ }
+ }
+ }
+ return false;
+ }
+
String getScalaVersion() {
String scala = getenv("SPARK_SCALA_VERSION");
if (scala != null) {
diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java
index eebd04fe4c5b1..8d95bc06d7a7d 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java
@@ -82,10 +82,6 @@ public List<String> buildCommand(Map<String, String> env)
javaOptsKeys.add("SPARK_BEELINE_OPTS");
yield "SPARK_BEELINE_MEMORY";
}
- case "org.apache.spark.sql.application.ConnectRepl" -> {
- isRemote = true;
- yield "SPARK_DRIVER_MEMORY";
- }
default -> "SPARK_DRIVER_MEMORY";
};
diff --git a/licenses-binary/LICENSE-xz.txt b/licenses-binary/LICENSE-xz.txt
new file mode 100644
index 0000000000000..4322122aecf1a
--- /dev/null
+++ b/licenses-binary/LICENSE-xz.txt
@@ -0,0 +1,11 @@
+Permission to use, copy, modify, and/or distribute this
+software for any purpose with or without fee is hereby granted.
+
+THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL
+WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
+THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR
+CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
+LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
+NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
+CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index b961f97cd877f..f97fefa245145 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -128,7 +128,7 @@ class CrossValidatorSuite
exception = intercept[SparkIllegalArgumentException] {
cv.fit(datasetWithFold)
},
- errorClass = "FIELD_NOT_FOUND",
+ condition = "FIELD_NOT_FOUND",
parameters = Map(
"fieldName" -> "`fold1`",
"fields" -> "`label`, `features`, `fold`")
diff --git a/pom.xml b/pom.xml
index 1cf74ed907ab1..b7c87beec0f92 100644
--- a/pom.xml
+++ b/pom.xml
@@ -124,7 +124,7 @@
3.4.0
- 3.25.4
+ 3.25.53.11.4${hadoop.version}3.9.2
@@ -137,7 +137,7 @@
3.8.010.16.1.1
- 1.14.1
+ 1.14.22.0.2shaded-protobuf11.0.23
@@ -183,11 +183,11 @@
2.17.22.17.22.3.1
- 1.1.10.6
+ 1.1.10.73.0.31.17.11.27.1
- 2.16.1
+ 2.17.02.6
@@ -195,11 +195,11 @@
2.12.04.1.17
- 14.0.1
+ 33.2.1-jre2.11.03.1.93.0.12
- 2.12.7
+ 2.13.03.5.23.0.02.2.11
@@ -212,10 +212,10 @@
1.1.01.9.01.78
- 1.14.1
+ 1.15.06.0.04.1.110.Final
- 2.0.65.Final
+ 2.0.66.Final75.15.11.01.11.0
@@ -227,6 +227,7 @@
-->
17.0.03.0.0-M2
+ 0.12.6org.fusesource.leveldbjni
@@ -276,6 +277,7 @@
compilecompiletest
+ testfalse
@@ -677,6 +679,23 @@
ivy${ivy.version}
+
+ io.jsonwebtoken
+ jjwt-api
+ ${jjwt.version}
+
+
+ io.jsonwebtoken
+ jjwt-impl
+ ${jjwt.version}
+ ${jjwt.deps.scope}
+
+
+ io.jsonwebtoken
+ jjwt-jackson
+ ${jjwt.version}
+ ${jjwt.deps.scope}
+ com.google.code.findbugsjsr305
@@ -2615,7 +2634,7 @@
io.airliftaircompressor
- 0.27
+ 2.0.2org.apache.orc
@@ -3401,6 +3420,7 @@
org.spark-project.spark:unusedcom.google.guava:guava
+ com.google.guava:failureaccessorg.jpmml:*
@@ -3809,6 +3829,9 @@
sparkr
+
+ jjwt
+ aarch64
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 21638e4816309..ece4504395f12 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -125,8 +125,64 @@ object MimaExcludes {
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Observation"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Observation$"),
+ // SPARK-49414: Remove Logging from DataFrameReader.
+ ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.DataFrameReader"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logName"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.log"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logInfo"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logDebug"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logTrace"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logWarning"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logError"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logInfo"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logDebug"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logTrace"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logWarning"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logError"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.isTraceEnabled"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.initializeLogIfNecessary"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.initializeLogIfNecessary"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.initializeLogIfNecessary$default$2"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.initializeForcefully"),
+
// SPARK-49425: Create a shared DataFrameWriter interface.
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrameWriter")
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrameWriter"),
+
+ // SPARK-49284: Shared Catalog interface.
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.catalog.CatalogMetadata"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.catalog.Column"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.catalog.Database"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.catalog.Function"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.catalog.Table"),
+
+ // SPARK-49426: Shared DataFrameWriterV2
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CreateTableWriter"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrameWriterV2"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.WriteConfigMethods"),
+
+ // SPARK-49424: Shared Encoders
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Encoders"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Encoders$"),
+
+ // SPARK-49413: Create a shared RuntimeConfig interface.
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.RuntimeConfig"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.RuntimeConfig$"),
+
+ // SPARK-49287: Shared Streaming interfaces
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.scheduler.SparkListenerEvent"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ForeachWriter"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.SourceProgress"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.SourceProgress$"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StateOperatorProgress"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StateOperatorProgress$"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryListener"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryListener$"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryListener$Event"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryListener$QueryIdleEvent"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryListener$QueryProgressEvent"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryListener$QueryStartedEvent"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryListener$QueryTerminatedEvent"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryStatus"),
)
// Default exclude rules
@@ -145,6 +201,8 @@ object MimaExcludes {
ProblemFilters.exclude[Problem]("org.apache.spark.sql.execution.*"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.internal.*"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.errors.*"),
+ ProblemFilters.exclude[Problem]("org.apache.spark.sql.classic.*"),
+ ProblemFilters.exclude[Problem]("org.apache.spark.sql.connect.*"),
// DSv2 catalog and expression APIs are unstable yet. We should enable this back.
ProblemFilters.exclude[Problem]("org.apache.spark.sql.connector.catalog.*"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.connector.expressions.*"),
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index b9d1d62c5ca5a..2f390cb70baa8 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -89,7 +89,7 @@ object BuildCommons {
// Google Protobuf version used for generating the protobuf.
// SPARK-41247: needs to be consistent with `protobuf.version` in `pom.xml`.
- val protoVersion = "3.25.4"
+ val protoVersion = "3.25.5"
// GRPC version used for Spark Connect.
val grpcVersion = "1.62.2"
}
@@ -420,11 +420,6 @@ object SparkBuild extends PomBuild {
enable(DockerIntegrationTests.settings)(dockerIntegrationTests)
- if (!profiles.contains("volcano")) {
- enable(Volcano.settings)(kubernetes)
- enable(Volcano.settings)(kubernetesIntegrationTests)
- }
-
enable(KubernetesIntegrationTests.settings)(kubernetesIntegrationTests)
enable(YARN.settings)(yarn)
@@ -433,10 +428,6 @@ object SparkBuild extends PomBuild {
enable(SparkR.settings)(core)
}
- if (!profiles.contains("opentelemetry")) {
- enable(OpenTelemetry.settings)(core)
- }
-
/**
* Adds the ability to run the spark shell directly from SBT without building an assembly
* jar.
@@ -1060,7 +1051,7 @@ object KubernetesIntegrationTests {
* Overrides to work around sbt's dependency resolution being different from Maven's.
*/
object DependencyOverrides {
- lazy val guavaVersion = sys.props.get("guava.version").getOrElse("14.0.1")
+ lazy val guavaVersion = sys.props.get("guava.version").getOrElse("33.1.0-jre")
lazy val settings = Seq(
dependencyOverrides += "com.google.guava" % "guava" % guavaVersion,
dependencyOverrides += "xerces" % "xercesImpl" % "2.12.2",
@@ -1326,20 +1317,6 @@ object SparkR {
)
}
-object Volcano {
- // Exclude all volcano file for Compile and Test
- lazy val settings = Seq(
- unmanagedSources / excludeFilter := HiddenFileFilter || "*Volcano*.scala"
- )
-}
-
-object OpenTelemetry {
- // Exclude all OpenTelemetry files for Compile and Test
- lazy val settings = Seq(
- unmanagedSources / excludeFilter := HiddenFileFilter || "OpenTelemetry*.scala"
- )
-}
-
trait SharedUnidocSettings {
import BuildCommons._
@@ -1375,6 +1352,7 @@ trait SharedUnidocSettings {
.map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/util/kvstore")))
.map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/catalyst")))
.map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/connect/")))
+ .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/classic/")))
.map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/execution")))
.map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/hive")))
.map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/catalog/v2/utils")))
diff --git a/python/docs/Makefile b/python/docs/Makefile
index 5058c1206171b..428b0d24b568e 100644
--- a/python/docs/Makefile
+++ b/python/docs/Makefile
@@ -16,7 +16,7 @@
# Minimal makefile for Sphinx documentation
# You can set these variables from the command line.
-SPHINXOPTS ?= "-W" "-j" "auto"
+SPHINXOPTS ?= "-W" "-j" "4"
SPHINXBUILD ?= sphinx-build
SOURCEDIR ?= source
BUILDDIR ?= build
diff --git a/python/docs/source/_static/spark-logo-dark.png b/python/docs/source/_static/spark-logo-dark.png
deleted file mode 100644
index 7460faec37fc7..0000000000000
Binary files a/python/docs/source/_static/spark-logo-dark.png and /dev/null differ
diff --git a/python/docs/source/_static/spark-logo-light.png b/python/docs/source/_static/spark-logo-light.png
deleted file mode 100644
index 41938560822ca..0000000000000
Binary files a/python/docs/source/_static/spark-logo-light.png and /dev/null differ
diff --git a/python/docs/source/conf.py b/python/docs/source/conf.py
index 66b985092faf1..5640ba151176d 100644
--- a/python/docs/source/conf.py
+++ b/python/docs/source/conf.py
@@ -205,8 +205,8 @@
"navbar_end": ["version-switcher", "theme-switcher", "navbar-icon-links"],
"footer_start": ["spark_footer", "sphinx-version"],
"logo": {
- "image_light": "_static/spark-logo-light.png",
- "image_dark": "_static/spark-logo-dark.png",
+ "image_light": "https://spark.apache.org/images/spark-logo.png",
+ "image_dark": "https://spark.apache.org/images/spark-logo-rev.svg",
},
"icon_links": [
{
@@ -234,7 +234,7 @@
# The name of an image file (relative to this directory) to place at the top
# of the sidebar.
-html_logo = "../../../docs/img/spark-logo-reverse.png"
+html_logo = "https://spark.apache.org/images/spark-logo-rev.svg"
# The name of an image file (within the static path) to use as a favicon of the
# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
diff --git a/python/docs/source/getting_started/install.rst b/python/docs/source/getting_started/install.rst
index 549656bea103e..88c0a8c26cc94 100644
--- a/python/docs/source/getting_started/install.rst
+++ b/python/docs/source/getting_started/install.rst
@@ -183,6 +183,7 @@ Package Supported version Note
Additional libraries that enhance functionality but are not included in the installation packages:
- **memory-profiler**: Used for PySpark UDF memory profiling, ``spark.profile.show(...)`` and ``spark.sql.pyspark.udf.profiler``.
+- **plotly**: Used for PySpark plotting, ``DataFrame.plot``.
Note that PySpark requires Java 17 or later with ``JAVA_HOME`` properly set and refer to |downloading|_.
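Since ``DataFrame.plot`` is new in this change set, a minimal sketch of the intended usage, assuming the accessor exposes a ``line`` method analogous to the pandas-on-Spark plotting API (the data and column names are illustrative):

.. code-block:: python

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(1, 2.0), (2, 4.0), (3, 8.0)], ["x", "y"])

    # With plotly installed, the accessor renders the frame as a plotly figure.
    fig = df.plot.line(x="x", y="y")
    fig.show()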
diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst
index dc4329c603241..4910a5b59273b 100644
--- a/python/docs/source/reference/pyspark.sql/functions.rst
+++ b/python/docs/source/reference/pyspark.sql/functions.rst
@@ -553,6 +553,7 @@ VARIANT Functions
try_variant_get
variant_get
try_parse_json
+ to_variant_object
XML Functions
diff --git a/python/docs/source/user_guide/sql/python_data_source.rst b/python/docs/source/user_guide/sql/python_data_source.rst
index 342b6f685d0b4..832987d19e5a4 100644
--- a/python/docs/source/user_guide/sql/python_data_source.rst
+++ b/python/docs/source/user_guide/sql/python_data_source.rst
@@ -452,3 +452,67 @@ We can also use the same data source in streaming reader and writer
.. code-block:: python
query = spark.readStream.format("fake").load().writeStream.format("fake").start("/output_path")
+
+Python Data Source Reader with direct Arrow Batch support for improved performance
+----------------------------------------------------------------------------------
+The Python Data Source Reader supports yielding Arrow batches directly, which can significantly improve data processing performance. By using the efficient Arrow format,
+this feature avoids the overhead of traditional row-by-row data processing and can deliver performance improvements of up to one order of magnitude, especially with large datasets.
+
+**Enabling Arrow Batch Support**:
+To enable this feature, implement the `read` method of your `DataSourceReader` (or `DataSourceStreamReader`)
+so that it yields `pyarrow.RecordBatch` objects. This approach simplifies data handling and reduces the number of I/O operations, which is particularly beneficial for large-scale data processing tasks.
+
+**Arrow Batch Example**:
+The following example demonstrates how to implement a basic Data Source using Arrow Batch support.
+
+.. code-block:: python
+
+ from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
+ from pyspark.sql import SparkSession
+ import pyarrow as pa
+
+ # Define the ArrowBatchDataSource
+ class ArrowBatchDataSource(DataSource):
+ """
+ A Data Source for testing Arrow Batch Serialization
+ """
+
+ @classmethod
+ def name(cls):
+ return "arrowbatch"
+
+ def schema(self):
+ return "key int, value string"
+
+ def reader(self, schema: str):
+ return ArrowBatchDataSourceReader(schema, self.options)
+
+ # Define the ArrowBatchDataSourceReader
+ class ArrowBatchDataSourceReader(DataSourceReader):
+ def __init__(self, schema, options):
+ self.schema: str = schema
+ self.options = options
+
+ def read(self, partition):
+ # Create Arrow Record Batch
+ keys = pa.array([1, 2, 3, 4, 5], type=pa.int32())
+ values = pa.array(["one", "two", "three", "four", "five"], type=pa.string())
+ schema = pa.schema([("key", pa.int32()), ("value", pa.string())])
+ record_batch = pa.RecordBatch.from_arrays([keys, values], schema=schema)
+ yield record_batch
+
+ def partitions(self):
+ # Define the number of partitions
+ num_part = 1
+ return [InputPartition(i) for i in range(num_part)]
+
+ # Initialize the Spark Session
+ spark = SparkSession.builder.appName("ArrowBatchExample").getOrCreate()
+
+ # Register the ArrowBatchDataSource
+ spark.dataSource.register(ArrowBatchDataSource)
+
+ # Load data using the custom data source
+ df = spark.read.format("arrowbatch").load()
+
+ df.show()
diff --git a/python/packaging/classic/setup.py b/python/packaging/classic/setup.py
index 79b74483f00dd..17cca326d0241 100755
--- a/python/packaging/classic/setup.py
+++ b/python/packaging/classic/setup.py
@@ -288,6 +288,7 @@ def run(self):
"pyspark.sql.connect.streaming.worker",
"pyspark.sql.functions",
"pyspark.sql.pandas",
+ "pyspark.sql.plot",
"pyspark.sql.protobuf",
"pyspark.sql.streaming",
"pyspark.sql.worker",
diff --git a/python/packaging/connect/setup.py b/python/packaging/connect/setup.py
index ab166c79747df..6ae16e9a9ad3a 100755
--- a/python/packaging/connect/setup.py
+++ b/python/packaging/connect/setup.py
@@ -77,6 +77,7 @@
"pyspark.sql.tests.connect.client",
"pyspark.sql.tests.connect.shell",
"pyspark.sql.tests.pandas",
+ "pyspark.sql.tests.plot",
"pyspark.sql.tests.streaming",
"pyspark.ml.tests.connect",
"pyspark.pandas.tests",
@@ -161,6 +162,7 @@
"pyspark.sql.connect.streaming.worker",
"pyspark.sql.functions",
"pyspark.sql.pandas",
+ "pyspark.sql.plot",
"pyspark.sql.protobuf",
"pyspark.sql.streaming",
"pyspark.sql.worker",
diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json
index 4061d024a83cd..92aeb15e21d1b 100644
--- a/python/pyspark/errors/error-conditions.json
+++ b/python/pyspark/errors/error-conditions.json
@@ -1088,6 +1088,11 @@
"Function `` should use only POSITIONAL or POSITIONAL OR KEYWORD arguments."
]
},
+ "UNSUPPORTED_PLOT_BACKEND": {
+ "message": [
+ "`` is not supported, it should be one of the values from "
+ ]
+ },
"UNSUPPORTED_SIGNATURE": {
"message": [
"Unsupported signature: ."
diff --git a/python/pyspark/pandas/config.py b/python/pyspark/pandas/config.py
index bfa88253dc6f4..6ed4adf21ff44 100644
--- a/python/pyspark/pandas/config.py
+++ b/python/pyspark/pandas/config.py
@@ -287,7 +287,8 @@ def validate(self, v: Any) -> None:
doc=(
"'plotting.sample_ratio' sets the proportion of data that will be plotted for sample-"
"based plots such as `plot.line` and `plot.area`. "
- "This option defaults to 'plotting.max_rows' option."
+ "If not set, it is derived from 'plotting.max_rows', by calculating the ratio of "
+ "'plotting.max_rows' to the total data size."
),
default=None,
types=(float, type(None)),
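To make the new wording concrete, a small illustrative sketch (the row counts are made up):

.. code-block:: python

    import pyspark.pandas as ps

    # With the default plotting.max_rows of 1000 and a frame of 100_000 rows,
    # the derived sample ratio is 1000 / 100_000 = 0.01, i.e. about 1% of rows.
    ps.set_option("plotting.max_rows", 1000)

    # Setting plotting.sample_ratio explicitly bypasses the derivation entirely.
    ps.set_option("plotting.sample_ratio", 0.1)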
diff --git a/python/pyspark/pandas/internal.py b/python/pyspark/pandas/internal.py
index 92d4a3357319f..4be345201ba65 100644
--- a/python/pyspark/pandas/internal.py
+++ b/python/pyspark/pandas/internal.py
@@ -43,6 +43,7 @@
)
from pyspark.sql.utils import is_timestamp_ntz_preferred, is_remote
from pyspark import pandas as ps
+from pyspark.pandas.spark import functions as SF
from pyspark.pandas._typing import Label
from pyspark.pandas.spark.utils import as_nullable_spark_type, force_decimal_precision_scale
from pyspark.pandas.data_type_ops.base import DataTypeOps
@@ -938,19 +939,10 @@ def attach_distributed_sequence_column(
+--------+---+
"""
if len(sdf.columns) > 0:
- if is_remote():
- from pyspark.sql.connect.column import Column as ConnectColumn
- from pyspark.sql.connect.expressions import DistributedSequenceID
-
- return sdf.select(
- ConnectColumn(DistributedSequenceID()).alias(column_name),
- "*",
- )
- else:
- return PySparkDataFrame(
- sdf._jdf.toDF().withSequenceColumn(column_name),
- sdf.sparkSession,
- )
+ return sdf.select(
+ SF.distributed_sequence_id().alias(column_name),
+ "*",
+ )
else:
cnt = sdf.count()
if cnt > 0:
diff --git a/python/pyspark/pandas/plot/core.py b/python/pyspark/pandas/plot/core.py
index ea76dfa25bd99..6f036b7669246 100644
--- a/python/pyspark/pandas/plot/core.py
+++ b/python/pyspark/pandas/plot/core.py
@@ -15,7 +15,6 @@
# limitations under the License.
#
-import bisect
import importlib
import math
@@ -25,7 +24,6 @@
from pandas.core.dtypes.inference import is_integer
from pyspark.sql import functions as F, Column
-from pyspark.sql.types import DoubleType
from pyspark.pandas.spark import functions as SF
from pyspark.pandas.missing import unsupported_function
from pyspark.pandas.config import get_option
@@ -70,19 +68,52 @@ class SampledPlotBase:
def get_sampled(self, data):
from pyspark.pandas import DataFrame, Series
+ if not isinstance(data, (DataFrame, Series)):
+ raise TypeError("Only DataFrame and Series are supported for plotting.")
+ if isinstance(data, Series):
+ data = data.to_frame()
+
fraction = get_option("plotting.sample_ratio")
- if fraction is None:
- fraction = 1 / (len(data) / get_option("plotting.max_rows"))
- fraction = min(1.0, fraction)
- self.fraction = fraction
-
- if isinstance(data, (DataFrame, Series)):
- if isinstance(data, Series):
- data = data.to_frame()
+ if fraction is not None:
+ self.fraction = fraction
sampled = data._internal.resolved_copy.spark_frame.sample(fraction=self.fraction)
return DataFrame(data._internal.with_new_sdf(sampled))._to_pandas()
else:
- raise TypeError("Only DataFrame and Series are supported for plotting.")
+ from pyspark.sql import Observation
+
+ max_rows = get_option("plotting.max_rows")
+ observation = Observation("ps plotting")
+ sdf = data._internal.resolved_copy.spark_frame.observe(
+ observation, F.count(F.lit(1)).alias("count")
+ )
+
+ rand_col_name = "__ps_plotting_sampled_plot_base_rand__"
+ id_col_name = "__ps_plotting_sampled_plot_base_id__"
+
+ sampled = (
+ sdf.select(
+ "*",
+ F.rand().alias(rand_col_name),
+ F.monotonically_increasing_id().alias(id_col_name),
+ )
+ .sort(rand_col_name)
+ .limit(max_rows + 1)
+ .coalesce(1)
+ .sortWithinPartitions(id_col_name)
+ .drop(rand_col_name, id_col_name)
+ )
+
+ pdf = DataFrame(data._internal.with_new_sdf(sampled))._to_pandas()
+
+ if len(pdf) > max_rows:
+ try:
+ self.fraction = float(max_rows) / observation.get["count"]
+ except Exception:
+ pass
+ return pdf[:max_rows]
+ else:
+ self.fraction = 1.0
+ return pdf
def set_result_text(self, ax):
assert hasattr(self, "fraction")
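The rewritten sampling path limits and counts in a single job; a standalone sketch of the same pattern in plain PySpark (the helper column names are placeholders):

.. code-block:: python

    from pyspark.sql import Observation, SparkSession
    from pyspark.sql import functions as F

    spark = SparkSession.builder.getOrCreate()
    sdf = spark.range(10_000).withColumn("v", F.rand())

    max_rows = 1000
    observation = Observation("row count")
    observed = sdf.observe(observation, F.count(F.lit(1)).alias("count"))

    # Randomly pick at most max_rows + 1 rows while keeping a stable order.
    sampled = (
        observed.select(
            "*",
            F.rand().alias("__rand__"),
            F.monotonically_increasing_id().alias("__id__"),
        )
        .sort("__rand__")
        .limit(max_rows + 1)
        .coalesce(1)
        .sortWithinPartitions("__id__")
        .drop("__rand__", "__id__")
    )
    pdf = sampled.toPandas()

    # After the action has run, the observed total row count is available,
    # which is what the plotting code uses to report the effective fraction.
    if len(pdf) > max_rows:
        fraction = float(max_rows) / observation.get["count"]
        pdf = pdf[:max_rows]
    else:
        fraction = 1.0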
@@ -182,22 +213,16 @@ def compute_hist(psdf, bins):
colnames = sdf.columns
bucket_names = ["__{}_bucket".format(colname) for colname in colnames]
- # TODO(SPARK-49202): register this function in scala side
- @F.udf(returnType=DoubleType())
- def binary_search_for_buckets(value):
- # Given bins = [1.0, 2.0, 3.0, 4.0]
- # the intervals are:
- # [1.0, 2.0) -> 0.0
- # [2.0, 3.0) -> 1.0
- # [3.0, 4.0] -> 2.0 (the last bucket is a closed interval)
- if value < bins[0] or value > bins[-1]:
- raise ValueError(f"value {value} out of the bins bounds: [{bins[0]}, {bins[-1]}]")
-
- if value == bins[-1]:
- idx = len(bins) - 2
- else:
- idx = bisect.bisect(bins, value) - 1
- return float(idx)
+ # refers to org.apache.spark.ml.feature.Bucketizer#binarySearchForBuckets
+ def binary_search_for_buckets(value: Column):
+ index = SF.binary_search(F.lit(bins), value)
+ bucket = F.when(index >= 0, index).otherwise(-index - 2)
+ unboundErrMsg = F.lit(f"value %s out of the bins bounds: [{bins[0]}, {bins[-1]}]")
+ return (
+ F.when(value == F.lit(bins[-1]), F.lit(len(bins) - 2))
+ .when(value.between(F.lit(bins[0]), F.lit(bins[-1])), bucket)
+ .otherwise(F.raise_error(F.printf(unboundErrMsg, value)))
+ )
output_df = (
sdf.select(
@@ -205,10 +230,10 @@ def binary_search_for_buckets(value):
F.array([F.col(colname).cast("double") for colname in colnames])
).alias("__group_id", "__value")
)
- # to match handleInvalid="skip" in Bucketizer
- .where(F.col("__value").isNotNull() & ~F.col("__value").isNaN()).select(
+ .where(F.col("__value").isNotNull() & ~F.col("__value").isNaN())
+ .select(
F.col("__group_id"),
- binary_search_for_buckets(F.col("__value")).alias("__bucket"),
+ binary_search_for_buckets(F.col("__value")).cast("double").alias("__bucket"),
)
)
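The ``-index - 2`` mapping assumes the search reports a missing value as ``-(insertion point) - 1``, in the style of ``java.util.Arrays.binarySearch``; a small pure-Python check of the resulting bucket ids:

.. code-block:: python

    import bisect

    bins = [1.0, 2.0, 3.0, 4.0]

    def bucket_of(value):
        # Mirror of the column expression above:
        # [1.0, 2.0) -> 0, [2.0, 3.0) -> 1, [3.0, 4.0] -> 2 (last bucket is closed).
        if value == bins[-1]:
            return len(bins) - 2
        if not (bins[0] <= value <= bins[-1]):
            raise ValueError(f"value {value} out of the bins bounds: [{bins[0]}, {bins[-1]}]")
        return bisect.bisect(bins, value) - 1

    assert bucket_of(2.5) == 1   # binary search would return -3, and -(-3) - 2 == 1
    assert bucket_of(2.0) == 1   # exact hit at index 1
    assert bucket_of(4.0) == 2   # the upper bound falls into the last, closed bucket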
@@ -454,7 +479,7 @@ class PandasOnSparkPlotAccessor(PandasObject):
"pie": TopNPlotBase().get_top_n,
"bar": TopNPlotBase().get_top_n,
"barh": TopNPlotBase().get_top_n,
- "scatter": TopNPlotBase().get_top_n,
+ "scatter": SampledPlotBase().get_sampled,
"area": SampledPlotBase().get_sampled,
"line": SampledPlotBase().get_sampled,
}
@@ -548,7 +573,7 @@ def line(self, x=None, y=None, **kwargs):
"""
Plot DataFrame/Series as lines.
- This function is useful to plot lines using Series's values
+ This function is useful to plot lines using DataFrame’s values
as coordinates.
Parameters
@@ -614,6 +639,12 @@ def bar(self, x=None, y=None, **kwds):
"""
Vertical bar plot.
+ A bar plot is a plot that presents categorical data with rectangular
+ bars with lengths proportional to the values that they represent. A
+ bar plot shows comparisons among discrete categories. One axis of the
+ plot shows the specific categories being compared, and the other axis
+ represents a measured value.
+
Parameters
----------
x : label or position, optional
@@ -725,10 +756,10 @@ def barh(self, x=None, y=None, **kwargs):
Parameters
----------
- x : label or position, default DataFrame.index
- Column to be used for categories.
- y : label or position, default All numeric columns in dataframe
+ x : label or position, default All numeric columns in dataframe
Columns to be plotted from the DataFrame.
+ y : label or position, default DataFrame.index
+ Column to be used for categories.
**kwds
Keyword arguments to pass on to
:meth:`pyspark.pandas.DataFrame.plot` or :meth:`pyspark.pandas.Series.plot`.
@@ -739,6 +770,13 @@ def barh(self, x=None, y=None, **kwargs):
Return a custom object when ``backend!=plotly``.
Return an ndarray when ``subplots=True`` (matplotlib-only).
+ Notes
+ -----
+ In Plotly and Matplotlib, the interpretation of `x` and `y` for `barh` plots differs.
+ In Plotly, `x` refers to the values and `y` refers to the categories.
+ In Matplotlib, `x` refers to the categories and `y` refers to the values.
+ Ensure correct axis labeling based on the backend used.
+
See Also
--------
plotly.express.bar : Plot a vertical bar plot using plotly.
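A brief sketch of the axis swap described in the note above, matching the updated plotly test later in this diff (the data is illustrative):

.. code-block:: python

    import pandas as pd
    import pyspark.pandas as ps

    psdf = ps.from_pandas(pd.DataFrame({"lab": ["A", "B", "C"], "val": [10, 30, 20]}))

    # With the default plotly backend, x carries the measured values
    # and y carries the categories.
    fig = psdf.plot.barh(x="val", y="lab")
    fig.show()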
@@ -805,7 +843,17 @@ def barh(self, x=None, y=None, **kwargs):
def box(self, **kwds):
"""
- Make a box plot of the Series columns.
+ Make a box plot of the DataFrame columns.
+
+ A box plot is a method for graphically depicting groups of numerical data through
+ their quartiles. The box extends from the Q1 to Q3 quartile values of the data,
+ with a line at the median (Q2). The whiskers extend from the edges of the box to show
+ the range of the data. The position of the whiskers is set by default to
+ 1.5*IQR (IQR = Q3 - Q1) from the edges of the box. Outlier points are those past
+ the end of the whiskers.
+
+ A consideration when using this chart is that the box and the whiskers can overlap,
+ which is very common when plotting small sets of data.
Parameters
----------
@@ -859,9 +907,11 @@ def box(self, **kwds):
def hist(self, bins=10, **kwds):
"""
Draw one histogram of the DataFrame’s columns.
+
A `histogram`_ is a representation of the distribution of data.
This function calls :meth:`plotting.backend.plot`,
on each series in the DataFrame, resulting in one histogram per column.
+ This is useful when the DataFrame’s Series are in a similar scale.
.. _histogram: https://en.wikipedia.org/wiki/Histogram
@@ -910,6 +960,10 @@ def kde(self, bw_method=None, ind=None, **kwargs):
"""
Generate Kernel Density Estimate plot using Gaussian kernels.
+ In statistics, kernel density estimation (KDE) is a non-parametric way to
+ estimate the probability density function (PDF) of a random variable. This
+ function uses Gaussian kernels and includes automatic bandwidth determination.
+
Parameters
----------
bw_method : scalar
diff --git a/python/pyspark/pandas/spark/functions.py b/python/pyspark/pandas/spark/functions.py
index 6bef3d9b87c05..4bcf07f6f6503 100644
--- a/python/pyspark/pandas/spark/functions.py
+++ b/python/pyspark/pandas/spark/functions.py
@@ -174,6 +174,18 @@ def null_index(col: Column) -> Column:
return Column(sc._jvm.PythonSQLUtils.nullIndex(col._jc))
+def distributed_sequence_id() -> Column:
+ if is_remote():
+ from pyspark.sql.connect.functions.builtin import _invoke_function
+
+ return _invoke_function("distributed_sequence_id")
+ else:
+ from pyspark import SparkContext
+
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.PythonSQLUtils.distributed_sequence_id())
+
+
def collect_top_k(col: Column, num: int, reverse: bool) -> Column:
if is_remote():
from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns
@@ -187,6 +199,19 @@ def collect_top_k(col: Column, num: int, reverse: bool) -> Column:
return Column(sc._jvm.PythonSQLUtils.collect_top_k(col._jc, num, reverse))
+def binary_search(col: Column, value: Column) -> Column:
+ if is_remote():
+ from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns
+
+ return _invoke_function_over_columns("array_binary_search", col, value)
+
+ else:
+ from pyspark import SparkContext
+
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.PythonSQLUtils.binary_search(col._jc, value._jc))
+
+
def make_interval(unit: str, e: Union[Column, int, float]) -> Column:
unit_mapping = {
"YEAR": "years",
diff --git a/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py b/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py
index 37469db2c8f51..8d197649aaebe 100644
--- a/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py
+++ b/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py
@@ -105,9 +105,10 @@ def check_barh_plot_with_x_y(pdf, psdf, x, y):
self.assertEqual(pdf.plot.barh(x=x, y=y), psdf.plot.barh(x=x, y=y))
# this is testing plot with specified x and y
- pdf1 = pd.DataFrame({"lab": ["A", "B", "C"], "val": [10, 30, 20]})
+ pdf1 = pd.DataFrame({"lab": ["A", "B", "C"], "val": [10, 30, 20], "val2": [1.1, 2.2, 3.3]})
psdf1 = ps.from_pandas(pdf1)
- check_barh_plot_with_x_y(pdf1, psdf1, x="lab", y="val")
+ check_barh_plot_with_x_y(pdf1, psdf1, x="val", y="lab")
+ check_barh_plot_with_x_y(pdf1, psdf1, x=["val", "val2"], y="lab")
def test_barh_plot(self):
def check_barh_plot(pdf, psdf):
diff --git a/python/pyspark/sql/classic/dataframe.py b/python/pyspark/sql/classic/dataframe.py
index 0e890e3343e66..a2778cbc32c4c 100644
--- a/python/pyspark/sql/classic/dataframe.py
+++ b/python/pyspark/sql/classic/dataframe.py
@@ -73,6 +73,11 @@
from pyspark.sql.pandas.conversion import PandasConversionMixin
from pyspark.sql.pandas.map_ops import PandasMapOpsMixin
+try:
+ from pyspark.sql.plot import PySparkPlotAccessor
+except ImportError:
+ PySparkPlotAccessor = None # type: ignore
+
if TYPE_CHECKING:
from py4j.java_gateway import JavaObject
import pyarrow as pa
@@ -1849,6 +1854,12 @@ def toArrow(self) -> "pa.Table":
def toPandas(self) -> "PandasDataFrameLike":
return PandasConversionMixin.toPandas(self)
+ def transpose(self, indexColumn: Optional["ColumnOrName"] = None) -> ParentDataFrame:
+ if indexColumn is not None:
+ return DataFrame(self._jdf.transpose(_to_java_column(indexColumn)), self.sparkSession)
+ else:
+ return DataFrame(self._jdf.transpose(), self.sparkSession)
+
@property
def executionInfo(self) -> Optional["ExecutionInfo"]:
raise PySparkValueError(
@@ -1856,6 +1867,10 @@ def executionInfo(self) -> Optional["ExecutionInfo"]:
messageParameters={"member": "queryExecution"},
)
+ @property
+ def plot(self) -> PySparkPlotAccessor:
+ return PySparkPlotAccessor(self)
+
class DataFrameNaFunctions(ParentDataFrameNaFunctions):
def __init__(self, df: ParentDataFrame):
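A minimal sketch of the two classic-DataFrame additions; the data is illustrative, and passing the index column explicitly sidesteps any assumption about the default behaviour of ``transpose``:

.. code-block:: python

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([("m1", 1, 2), ("m2", 3, 4)], ["metric", "a", "b"])

    # Transpose rows and columns, using "metric" as the index column.
    df.transpose("metric").show()

    # The new property returns the plotly-backed accessor from pyspark.sql.plot.
    accessor = df.plot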
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 35dcf677fdb70..adba1b42a8bd6 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -1786,7 +1786,7 @@ def _fetch_enriched_error(self, info: "ErrorInfo") -> Optional[pb2.FetchErrorDet
req.user_context.user_id = self._user_id
try:
- return self._stub.FetchErrorDetails(req)
+ return self._stub.FetchErrorDetails(req, metadata=self._builder.metadata())
except grpc.RpcError:
return None
diff --git a/python/pyspark/sql/connect/client/reattach.py b/python/pyspark/sql/connect/client/reattach.py
index 0309a25f956a1..ea6788e858317 100644
--- a/python/pyspark/sql/connect/client/reattach.py
+++ b/python/pyspark/sql/connect/client/reattach.py
@@ -284,8 +284,14 @@ def _call_iter(self, iter_fun: Callable) -> Any:
raise e
def _create_reattach_execute_request(self) -> pb2.ReattachExecuteRequest:
+ server_side_session_id = (
+ None
+ if not self._initial_request.client_observed_server_side_session_id
+ else self._initial_request.client_observed_server_side_session_id
+ )
reattach = pb2.ReattachExecuteRequest(
session_id=self._initial_request.session_id,
+ client_observed_server_side_session_id=server_side_session_id,
user_context=self._initial_request.user_context,
operation_id=self._initial_request.operation_id,
)
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 442157eef0b75..59d79decf6690 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -86,6 +86,10 @@
from pyspark.sql.pandas.types import from_arrow_schema, to_arrow_schema
from pyspark.sql.pandas.functions import _validate_pandas_udf # type: ignore[attr-defined]
+try:
+ from pyspark.sql.plot import PySparkPlotAccessor
+except ImportError:
+ PySparkPlotAccessor = None # type: ignore
if TYPE_CHECKING:
from pyspark.sql.connect._typing import (
@@ -1783,7 +1787,7 @@ def __getitem__(
)
)
else:
- # TODO: revisit vanilla Spark's Dataset.col
+ # TODO: revisit classic Spark's Dataset.col
# if (sparkSession.sessionState.conf.supportQuotedRegexColumnName) {
# colRegex(colName)
# } else {
@@ -1858,6 +1862,12 @@ def toPandas(self) -> "PandasDataFrameLike":
self._execution_info = ei
return pdf
+ def transpose(self, indexColumn: Optional["ColumnOrName"] = None) -> ParentDataFrame:
+ return DataFrame(
+ plan.Transpose(self._plan, [F._to_col(indexColumn)] if indexColumn is not None else []),
+ self._session,
+ )
+
@property
def schema(self) -> StructType:
# Schema caching is correct in most cases. Connect is lazy by nature. This means that
@@ -2233,6 +2243,10 @@ def rdd(self) -> "RDD[Row]":
def executionInfo(self) -> Optional["ExecutionInfo"]:
return self._execution_info
+ @property
+ def plot(self) -> PySparkPlotAccessor:
+ return PySparkPlotAccessor(self)
+
class DataFrameNaFunctions(ParentDataFrameNaFunctions):
def __init__(self, df: ParentDataFrame):
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index db1cd1c013be5..0b5512b61925c 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -477,8 +477,30 @@ def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
def __repr__(self) -> str:
if self._value is None:
return "NULL"
- else:
- return f"{self._value}"
+ elif isinstance(self._dataType, DateType):
+ dt = DateType().fromInternal(self._value)
+ if dt is not None and isinstance(dt, datetime.date):
+ return dt.strftime("%Y-%m-%d")
+ elif isinstance(self._dataType, TimestampType):
+ ts = TimestampType().fromInternal(self._value)
+ if ts is not None and isinstance(ts, datetime.datetime):
+ return ts.strftime("%Y-%m-%d %H:%M:%S.%f")
+ elif isinstance(self._dataType, TimestampNTZType):
+ ts = TimestampNTZType().fromInternal(self._value)
+ if ts is not None and isinstance(ts, datetime.datetime):
+ return ts.strftime("%Y-%m-%d %H:%M:%S.%f")
+ elif isinstance(self._dataType, DayTimeIntervalType):
+ delta = DayTimeIntervalType().fromInternal(self._value)
+ if delta is not None and isinstance(delta, datetime.timedelta):
+ import pandas as pd
+
+ # Note: timedelta itself does not provide an isoformat method.
+ # Both Pandas and java.time.Duration provide it, but the format
+ # is slightly different:
+ # java.time.Duration only applies HOURS, MINUTES, SECONDS units,
+ # while Pandas applies all supported units.
+ return pd.Timedelta(delta).isoformat() # type: ignore[attr-defined]
+ return f"{self._value}"
class ColumnReference(Expression):
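A hedged illustration of what the new ``__repr__`` branches change; the exact surrounding ``Column<...>`` wrapper may differ, but the added code renders the literal itself as a formatted date or timestamp rather than the internal integer value:

.. code-block:: python

    import datetime
    from pyspark.sql.connect.functions import lit

    # Date literals now render as "2024-09-01" instead of a raw day ordinal.
    print(lit(datetime.date(2024, 9, 1)))

    # Timestamp literals render with microsecond precision,
    # e.g. "2024-09-01 10:30:00.000000".
    print(lit(datetime.datetime(2024, 9, 1, 10, 30)))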
diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py
index ad6dbbf58e48d..7fed175cbc8ea 100644
--- a/python/pyspark/sql/connect/functions/builtin.py
+++ b/python/pyspark/sql/connect/functions/builtin.py
@@ -71,6 +71,7 @@
StringType,
)
from pyspark.sql.utils import enum_to_value as _enum_to_value
+from pyspark.util import JVM_INT_MAX
# The implementation of pandas_udf is embedded in pyspark.sql.function.pandas_udf
# for code reuse.
@@ -1126,11 +1127,12 @@ def grouping_id(*cols: "ColumnOrName") -> Column:
def count_min_sketch(
col: "ColumnOrName",
- eps: "ColumnOrName",
- confidence: "ColumnOrName",
- seed: "ColumnOrName",
+ eps: Union[Column, float],
+ confidence: Union[Column, float],
+ seed: Optional[Union[Column, int]] = None,
) -> Column:
- return _invoke_function_over_columns("count_min_sketch", col, eps, confidence, seed)
+ _seed = lit(random.randint(0, JVM_INT_MAX)) if seed is None else lit(seed)
+ return _invoke_function_over_columns("count_min_sketch", col, lit(eps), lit(confidence), _seed)
count_min_sketch.__doc__ = pysparkfuncs.count_min_sketch.__doc__
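With the relaxed signature, ``eps`` and ``confidence`` can be plain floats and the seed may be omitted, in which case a random seed is generated on the client; a short sketch, assuming the classic API mirrors this Connect change:

.. code-block:: python

    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F

    spark = SparkSession.builder.getOrCreate()
    df = spark.range(1000)

    # eps/confidence as floats, seed omitted.
    df.select(F.count_min_sketch("id", 0.1, 0.95).alias("sketch")).show(truncate=False)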
@@ -2071,6 +2073,13 @@ def try_parse_json(col: "ColumnOrName") -> Column:
try_parse_json.__doc__ = pysparkfuncs.try_parse_json.__doc__
+def to_variant_object(col: "ColumnOrName") -> Column:
+ return _invoke_function("to_variant_object", _to_col(col))
+
+
+to_variant_object.__doc__ = pysparkfuncs.to_variant_object.__doc__
+
+
def parse_json(col: "ColumnOrName") -> Column:
return _invoke_function("parse_json", _to_col(col))
@@ -2481,8 +2490,14 @@ def sentences(
sentences.__doc__ = pysparkfuncs.sentences.__doc__
-def substring(str: "ColumnOrName", pos: int, len: int) -> Column:
- return _invoke_function("substring", _to_col(str), lit(pos), lit(len))
+def substring(
+ str: "ColumnOrName",
+ pos: Union["ColumnOrName", int],
+ len: Union["ColumnOrName", int],
+) -> Column:
+ _pos = lit(pos) if isinstance(pos, int) else _to_col(pos)
+ _len = lit(len) if isinstance(len, int) else _to_col(len)
+ return _invoke_function("substring", _to_col(str), _pos, _len)
substring.__doc__ = pysparkfuncs.substring.__doc__
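``pos`` and ``len`` now accept either ints or columns; a brief sketch, assuming the classic ``substring`` mirrors this Connect change (which the shared docstring suggests):

.. code-block:: python

    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([("Spark SQL", 5, 1)], ["s", "p", "l"])

    df.select(
        F.substring(df.s, 1, 5).alias("literal_args"),                    # plain ints, as before
        F.substring(df.s, F.col("p"), F.col("l")).alias("column_args"),   # columns now allowed
    ).show()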
diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index 69fbcda12ae72..46f13f893c7fa 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -101,7 +101,7 @@ def __init__(
def __repr__(self) -> str:
# the expressions are not resolved here,
- # so the string representation can be different from vanilla PySpark.
+ # so the string representation can be different from classic PySpark.
grouping_str = ", ".join(str(e._expr) for e in self._grouping_cols)
grouping_str = f"grouping expressions: [{grouping_str}]"
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 958626280e41c..fbed0eabc684f 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1329,6 +1329,27 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation:
return plan
+class Transpose(LogicalPlan):
+ """Logical plan object for a transpose operation."""
+
+ def __init__(
+ self,
+ child: Optional["LogicalPlan"],
+ index_columns: Sequence[Column],
+ ) -> None:
+ super().__init__(child)
+ self.index_columns = index_columns
+
+ def plan(self, session: "SparkConnectClient") -> proto.Relation:
+ assert self._child is not None
+ plan = self._create_proto_relation()
+ plan.transpose.input.CopyFrom(self._child.plan(session))
+ if self.index_columns is not None and len(self.index_columns) > 0:
+ for index_column in self.index_columns:
+ plan.transpose.index_columns.append(index_column.to_plan(session))
+ return plan
+
+
class CollectMetrics(LogicalPlan):
"""Logical plan object for a CollectMetrics operation."""
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 9f4d1e717a28d..ee625241600ff 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto\x1a\x1aspark/connect/common.proto"\xe9\x1a\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\x32\x13.spark.connect.JoinH\x00R\x04join\x12\x34\n\x06set_op\x18\x06 \x01(\x0b\x32\x1b.spark.connect.SetOperationH\x00R\x05setOp\x12)\n\x04sort\x18\x07 \x01(\x0b\x32\x13.spark.connect.SortH\x00R\x04sort\x12,\n\x05limit\x18\x08 \x01(\x0b\x32\x14.spark.connect.LimitH\x00R\x05limit\x12\x38\n\taggregate\x18\t \x01(\x0b\x32\x18.spark.connect.AggregateH\x00R\taggregate\x12&\n\x03sql\x18\n \x01(\x0b\x32\x12.spark.connect.SQLH\x00R\x03sql\x12\x45\n\x0elocal_relation\x18\x0b \x01(\x0b\x32\x1c.spark.connect.LocalRelationH\x00R\rlocalRelation\x12/\n\x06sample\x18\x0c \x01(\x0b\x32\x15.spark.connect.SampleH\x00R\x06sample\x12/\n\x06offset\x18\r \x01(\x0b\x32\x15.spark.connect.OffsetH\x00R\x06offset\x12>\n\x0b\x64\x65\x64uplicate\x18\x0e \x01(\x0b\x32\x1a.spark.connect.DeduplicateH\x00R\x0b\x64\x65\x64uplicate\x12,\n\x05range\x18\x0f \x01(\x0b\x32\x14.spark.connect.RangeH\x00R\x05range\x12\x45\n\x0esubquery_alias\x18\x10 \x01(\x0b\x32\x1c.spark.connect.SubqueryAliasH\x00R\rsubqueryAlias\x12>\n\x0brepartition\x18\x11 \x01(\x0b\x32\x1a.spark.connect.RepartitionH\x00R\x0brepartition\x12*\n\x05to_df\x18\x12 \x01(\x0b\x32\x13.spark.connect.ToDFH\x00R\x04toDf\x12U\n\x14with_columns_renamed\x18\x13 \x01(\x0b\x32!.spark.connect.WithColumnsRenamedH\x00R\x12withColumnsRenamed\x12<\n\x0bshow_string\x18\x14 \x01(\x0b\x32\x19.spark.connect.ShowStringH\x00R\nshowString\x12)\n\x04\x64rop\x18\x15 \x01(\x0b\x32\x13.spark.connect.DropH\x00R\x04\x64rop\x12)\n\x04tail\x18\x16 \x01(\x0b\x32\x13.spark.connect.TailH\x00R\x04tail\x12?\n\x0cwith_columns\x18\x17 \x01(\x0b\x32\x1a.spark.connect.WithColumnsH\x00R\x0bwithColumns\x12)\n\x04hint\x18\x18 \x01(\x0b\x32\x13.spark.connect.HintH\x00R\x04hint\x12\x32\n\x07unpivot\x18\x19 \x01(\x0b\x32\x16.spark.connect.UnpivotH\x00R\x07unpivot\x12\x36\n\tto_schema\x18\x1a \x01(\x0b\x32\x17.spark.connect.ToSchemaH\x00R\x08toSchema\x12\x64\n\x19repartition_by_expression\x18\x1b \x01(\x0b\x32&.spark.connect.RepartitionByExpressionH\x00R\x17repartitionByExpression\x12\x45\n\x0emap_partitions\x18\x1c \x01(\x0b\x32\x1c.spark.connect.MapPartitionsH\x00R\rmapPartitions\x12H\n\x0f\x63ollect_metrics\x18\x1d \x01(\x0b\x32\x1d.spark.connect.CollectMetricsH\x00R\x0e\x63ollectMetrics\x12,\n\x05parse\x18\x1e \x01(\x0b\x32\x14.spark.connect.ParseH\x00R\x05parse\x12\x36\n\tgroup_map\x18\x1f \x01(\x0b\x32\x17.spark.connect.GroupMapH\x00R\x08groupMap\x12=\n\x0c\x63o_group_map\x18 \x01(\x0b\x32\x19.spark.connect.CoGroupMapH\x00R\ncoGroupMap\x12\x45\n\x0ewith_watermark\x18! 
\x01(\x0b\x32\x1c.spark.connect.WithWatermarkH\x00R\rwithWatermark\x12\x63\n\x1a\x61pply_in_pandas_with_state\x18" \x01(\x0b\x32%.spark.connect.ApplyInPandasWithStateH\x00R\x16\x61pplyInPandasWithState\x12<\n\x0bhtml_string\x18# \x01(\x0b\x32\x19.spark.connect.HtmlStringH\x00R\nhtmlString\x12X\n\x15\x63\x61\x63hed_local_relation\x18$ \x01(\x0b\x32".spark.connect.CachedLocalRelationH\x00R\x13\x63\x61\x63hedLocalRelation\x12[\n\x16\x63\x61\x63hed_remote_relation\x18% \x01(\x0b\x32#.spark.connect.CachedRemoteRelationH\x00R\x14\x63\x61\x63hedRemoteRelation\x12\x8e\x01\n)common_inline_user_defined_table_function\x18& \x01(\x0b\x32\x33.spark.connect.CommonInlineUserDefinedTableFunctionH\x00R$commonInlineUserDefinedTableFunction\x12\x37\n\nas_of_join\x18\' \x01(\x0b\x32\x17.spark.connect.AsOfJoinH\x00R\x08\x61sOfJoin\x12\x85\x01\n&common_inline_user_defined_data_source\x18( \x01(\x0b\x32\x30.spark.connect.CommonInlineUserDefinedDataSourceH\x00R!commonInlineUserDefinedDataSource\x12\x45\n\x0ewith_relations\x18) \x01(\x0b\x32\x1c.spark.connect.WithRelationsH\x00R\rwithRelations\x12\x30\n\x07\x66ill_na\x18Z \x01(\x0b\x32\x15.spark.connect.NAFillH\x00R\x06\x66illNa\x12\x30\n\x07\x64rop_na\x18[ \x01(\x0b\x32\x15.spark.connect.NADropH\x00R\x06\x64ropNa\x12\x34\n\x07replace\x18\\ \x01(\x0b\x32\x18.spark.connect.NAReplaceH\x00R\x07replace\x12\x36\n\x07summary\x18\x64 \x01(\x0b\x32\x1a.spark.connect.StatSummaryH\x00R\x07summary\x12\x39\n\x08\x63rosstab\x18\x65 \x01(\x0b\x32\x1b.spark.connect.StatCrosstabH\x00R\x08\x63rosstab\x12\x39\n\x08\x64\x65scribe\x18\x66 \x01(\x0b\x32\x1b.spark.connect.StatDescribeH\x00R\x08\x64\x65scribe\x12*\n\x03\x63ov\x18g \x01(\x0b\x32\x16.spark.connect.StatCovH\x00R\x03\x63ov\x12-\n\x04\x63orr\x18h \x01(\x0b\x32\x17.spark.connect.StatCorrH\x00R\x04\x63orr\x12L\n\x0f\x61pprox_quantile\x18i \x01(\x0b\x32!.spark.connect.StatApproxQuantileH\x00R\x0e\x61pproxQuantile\x12=\n\nfreq_items\x18j \x01(\x0b\x32\x1c.spark.connect.StatFreqItemsH\x00R\tfreqItems\x12:\n\tsample_by\x18k \x01(\x0b\x32\x1b.spark.connect.StatSampleByH\x00R\x08sampleBy\x12\x33\n\x07\x63\x61talog\x18\xc8\x01 \x01(\x0b\x32\x16.spark.connect.CatalogH\x00R\x07\x63\x61talog\x12\x35\n\textension\x18\xe6\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x33\n\x07unknown\x18\xe7\x07 \x01(\x0b\x32\x16.spark.connect.UnknownH\x00R\x07unknownB\n\n\x08rel_type"\t\n\x07Unknown"\x8e\x01\n\x0eRelationCommon\x12#\n\x0bsource_info\x18\x01 \x01(\tB\x02\x18\x01R\nsourceInfo\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x00R\x06planId\x88\x01\x01\x12-\n\x06origin\x18\x03 \x01(\x0b\x32\x15.spark.connect.OriginR\x06originB\n\n\x08_plan_id"\xde\x03\n\x03SQL\x12\x14\n\x05query\x18\x01 \x01(\tR\x05query\x12\x34\n\x04\x61rgs\x18\x02 \x03(\x0b\x32\x1c.spark.connect.SQL.ArgsEntryB\x02\x18\x01R\x04\x61rgs\x12@\n\x08pos_args\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralB\x02\x18\x01R\x07posArgs\x12O\n\x0fnamed_arguments\x18\x04 \x03(\x0b\x32&.spark.connect.SQL.NamedArgumentsEntryR\x0enamedArguments\x12>\n\rpos_arguments\x18\x05 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0cposArguments\x1aZ\n\tArgsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x37\n\x05value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x05value:\x02\x38\x01\x1a\\\n\x13NamedArgumentsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12/\n\x05value\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05value:\x02\x38\x01"u\n\rWithRelations\x12+\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04root\x12\x37\n\nreferences\x18\x02 
\x03(\x0b\x32\x17.spark.connect.RelationR\nreferences"\x97\x05\n\x04Read\x12\x41\n\x0bnamed_table\x18\x01 \x01(\x0b\x32\x1e.spark.connect.Read.NamedTableH\x00R\nnamedTable\x12\x41\n\x0b\x64\x61ta_source\x18\x02 \x01(\x0b\x32\x1e.spark.connect.Read.DataSourceH\x00R\ndataSource\x12!\n\x0cis_streaming\x18\x03 \x01(\x08R\x0bisStreaming\x1a\xc0\x01\n\nNamedTable\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x12\x45\n\x07options\x18\x02 \x03(\x0b\x32+.spark.connect.Read.NamedTable.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x95\x02\n\nDataSource\x12\x1b\n\x06\x66ormat\x18\x01 \x01(\tH\x00R\x06\x66ormat\x88\x01\x01\x12\x1b\n\x06schema\x18\x02 \x01(\tH\x01R\x06schema\x88\x01\x01\x12\x45\n\x07options\x18\x03 \x03(\x0b\x32+.spark.connect.Read.DataSource.OptionsEntryR\x07options\x12\x14\n\x05paths\x18\x04 \x03(\tR\x05paths\x12\x1e\n\npredicates\x18\x05 \x03(\tR\npredicates\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07_formatB\t\n\x07_schemaB\x0b\n\tread_type"u\n\x07Project\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12;\n\x0b\x65xpressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0b\x65xpressions"p\n\x06\x46ilter\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\x95\x05\n\x04Join\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinType\x12#\n\rusing_columns\x18\x05 \x03(\tR\x0cusingColumns\x12K\n\x0ejoin_data_type\x18\x06 \x01(\x0b\x32 .spark.connect.Join.JoinDataTypeH\x00R\x0cjoinDataType\x88\x01\x01\x1a\\\n\x0cJoinDataType\x12$\n\x0eis_left_struct\x18\x01 \x01(\x08R\x0cisLeftStruct\x12&\n\x0fis_right_struct\x18\x02 \x01(\x08R\risRightStruct"\xd0\x01\n\x08JoinType\x12\x19\n\x15JOIN_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOIN_TYPE_INNER\x10\x01\x12\x18\n\x14JOIN_TYPE_FULL_OUTER\x10\x02\x12\x18\n\x14JOIN_TYPE_LEFT_OUTER\x10\x03\x12\x19\n\x15JOIN_TYPE_RIGHT_OUTER\x10\x04\x12\x17\n\x13JOIN_TYPE_LEFT_ANTI\x10\x05\x12\x17\n\x13JOIN_TYPE_LEFT_SEMI\x10\x06\x12\x13\n\x0fJOIN_TYPE_CROSS\x10\x07\x42\x11\n\x0f_join_data_type"\xdf\x03\n\x0cSetOperation\x12\x36\n\nleft_input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\tleftInput\x12\x38\n\x0bright_input\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\nrightInput\x12\x45\n\x0bset_op_type\x18\x03 \x01(\x0e\x32%.spark.connect.SetOperation.SetOpTypeR\tsetOpType\x12\x1a\n\x06is_all\x18\x04 \x01(\x08H\x00R\x05isAll\x88\x01\x01\x12\x1c\n\x07\x62y_name\x18\x05 \x01(\x08H\x01R\x06\x62yName\x88\x01\x01\x12\x37\n\x15\x61llow_missing_columns\x18\x06 \x01(\x08H\x02R\x13\x61llowMissingColumns\x88\x01\x01"r\n\tSetOpType\x12\x1b\n\x17SET_OP_TYPE_UNSPECIFIED\x10\x00\x12\x19\n\x15SET_OP_TYPE_INTERSECT\x10\x01\x12\x15\n\x11SET_OP_TYPE_UNION\x10\x02\x12\x16\n\x12SET_OP_TYPE_EXCEPT\x10\x03\x42\t\n\x07_is_allB\n\n\x08_by_nameB\x18\n\x16_allow_missing_columns"L\n\x05Limit\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"O\n\x06Offset\x12-\n\x05input\x18\x01 
\x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06offset\x18\x02 \x01(\x05R\x06offset"K\n\x04Tail\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"\xfe\x05\n\tAggregate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x41\n\ngroup_type\x18\x02 \x01(\x0e\x32".spark.connect.Aggregate.GroupTypeR\tgroupType\x12L\n\x14grouping_expressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12N\n\x15\x61ggregate_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x14\x61ggregateExpressions\x12\x34\n\x05pivot\x18\x05 \x01(\x0b\x32\x1e.spark.connect.Aggregate.PivotR\x05pivot\x12J\n\rgrouping_sets\x18\x06 \x03(\x0b\x32%.spark.connect.Aggregate.GroupingSetsR\x0cgroupingSets\x1ao\n\x05Pivot\x12+\n\x03\x63ol\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x03\x63ol\x12\x39\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x1aL\n\x0cGroupingSets\x12<\n\x0cgrouping_set\x18\x01 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0bgroupingSet"\x9f\x01\n\tGroupType\x12\x1a\n\x16GROUP_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12GROUP_TYPE_GROUPBY\x10\x01\x12\x15\n\x11GROUP_TYPE_ROLLUP\x10\x02\x12\x13\n\x0fGROUP_TYPE_CUBE\x10\x03\x12\x14\n\x10GROUP_TYPE_PIVOT\x10\x04\x12\x1c\n\x18GROUP_TYPE_GROUPING_SETS\x10\x05"\xa0\x01\n\x04Sort\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x39\n\x05order\x18\x02 \x03(\x0b\x32#.spark.connect.Expression.SortOrderR\x05order\x12 \n\tis_global\x18\x03 \x01(\x08H\x00R\x08isGlobal\x88\x01\x01\x42\x0c\n\n_is_global"\x8d\x01\n\x04\x44rop\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x33\n\x07\x63olumns\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x07\x63olumns\x12!\n\x0c\x63olumn_names\x18\x03 \x03(\tR\x0b\x63olumnNames"\xf0\x01\n\x0b\x44\x65\x64uplicate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames\x12\x32\n\x13\x61ll_columns_as_keys\x18\x03 \x01(\x08H\x00R\x10\x61llColumnsAsKeys\x88\x01\x01\x12.\n\x10within_watermark\x18\x04 \x01(\x08H\x01R\x0fwithinWatermark\x88\x01\x01\x42\x16\n\x14_all_columns_as_keysB\x13\n\x11_within_watermark"Y\n\rLocalRelation\x12\x17\n\x04\x64\x61ta\x18\x01 \x01(\x0cH\x00R\x04\x64\x61ta\x88\x01\x01\x12\x1b\n\x06schema\x18\x02 \x01(\tH\x01R\x06schema\x88\x01\x01\x42\x07\n\x05_dataB\t\n\x07_schema"H\n\x13\x43\x61\x63hedLocalRelation\x12\x12\n\x04hash\x18\x03 \x01(\tR\x04hashJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03R\x06userIdR\tsessionId"7\n\x14\x43\x61\x63hedRemoteRelation\x12\x1f\n\x0brelation_id\x18\x01 \x01(\tR\nrelationId"\x91\x02\n\x06Sample\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1f\n\x0blower_bound\x18\x02 \x01(\x01R\nlowerBound\x12\x1f\n\x0bupper_bound\x18\x03 \x01(\x01R\nupperBound\x12.\n\x10with_replacement\x18\x04 \x01(\x08H\x00R\x0fwithReplacement\x88\x01\x01\x12\x17\n\x04seed\x18\x05 \x01(\x03H\x01R\x04seed\x88\x01\x01\x12/\n\x13\x64\x65terministic_order\x18\x06 \x01(\x08R\x12\x64\x65terministicOrderB\x13\n\x11_with_replacementB\x07\n\x05_seed"\x91\x01\n\x05Range\x12\x19\n\x05start\x18\x01 \x01(\x03H\x00R\x05start\x88\x01\x01\x12\x10\n\x03\x65nd\x18\x02 \x01(\x03R\x03\x65nd\x12\x12\n\x04step\x18\x03 \x01(\x03R\x04step\x12*\n\x0enum_partitions\x18\x04 
\x01(\x05H\x01R\rnumPartitions\x88\x01\x01\x42\x08\n\x06_startB\x11\n\x0f_num_partitions"r\n\rSubqueryAlias\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias\x12\x1c\n\tqualifier\x18\x03 \x03(\tR\tqualifier"\x8e\x01\n\x0bRepartition\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12%\n\x0enum_partitions\x18\x02 \x01(\x05R\rnumPartitions\x12\x1d\n\x07shuffle\x18\x03 \x01(\x08H\x00R\x07shuffle\x88\x01\x01\x42\n\n\x08_shuffle"\x8e\x01\n\nShowString\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x19\n\x08num_rows\x18\x02 \x01(\x05R\x07numRows\x12\x1a\n\x08truncate\x18\x03 \x01(\x05R\x08truncate\x12\x1a\n\x08vertical\x18\x04 \x01(\x08R\x08vertical"r\n\nHtmlString\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x19\n\x08num_rows\x18\x02 \x01(\x05R\x07numRows\x12\x1a\n\x08truncate\x18\x03 \x01(\x05R\x08truncate"\\\n\x0bStatSummary\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1e\n\nstatistics\x18\x02 \x03(\tR\nstatistics"Q\n\x0cStatDescribe\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols"e\n\x0cStatCrosstab\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2"`\n\x07StatCov\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2"\x89\x01\n\x08StatCorr\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2\x12\x1b\n\x06method\x18\x04 \x01(\tH\x00R\x06method\x88\x01\x01\x42\t\n\x07_method"\xa4\x01\n\x12StatApproxQuantile\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12$\n\rprobabilities\x18\x03 \x03(\x01R\rprobabilities\x12%\n\x0erelative_error\x18\x04 \x01(\x01R\rrelativeError"}\n\rStatFreqItems\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\x1d\n\x07support\x18\x03 \x01(\x01H\x00R\x07support\x88\x01\x01\x42\n\n\x08_support"\xb5\x02\n\x0cStatSampleBy\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12+\n\x03\x63ol\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x03\x63ol\x12\x42\n\tfractions\x18\x03 \x03(\x0b\x32$.spark.connect.StatSampleBy.FractionR\tfractions\x12\x17\n\x04seed\x18\x05 \x01(\x03H\x00R\x04seed\x88\x01\x01\x1a\x63\n\x08\x46raction\x12;\n\x07stratum\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x07stratum\x12\x1a\n\x08\x66raction\x18\x02 \x01(\x01R\x08\x66ractionB\x07\n\x05_seed"\x86\x01\n\x06NAFill\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\x39\n\x06values\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values"\x86\x01\n\x06NADrop\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\'\n\rmin_non_nulls\x18\x03 \x01(\x05H\x00R\x0bminNonNulls\x88\x01\x01\x42\x10\n\x0e_min_non_nulls"\xa8\x02\n\tNAReplace\x12-\n\x05input\x18\x01 
\x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12H\n\x0creplacements\x18\x03 \x03(\x0b\x32$.spark.connect.NAReplace.ReplacementR\x0creplacements\x1a\x8d\x01\n\x0bReplacement\x12>\n\told_value\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x08oldValue\x12>\n\tnew_value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x08newValue"X\n\x04ToDF\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames"\xfe\x02\n\x12WithColumnsRenamed\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12i\n\x12rename_columns_map\x18\x02 \x03(\x0b\x32\x37.spark.connect.WithColumnsRenamed.RenameColumnsMapEntryB\x02\x18\x01R\x10renameColumnsMap\x12\x42\n\x07renames\x18\x03 \x03(\x0b\x32(.spark.connect.WithColumnsRenamed.RenameR\x07renames\x1a\x43\n\x15RenameColumnsMapEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x45\n\x06Rename\x12\x19\n\x08\x63ol_name\x18\x01 \x01(\tR\x07\x63olName\x12 \n\x0cnew_col_name\x18\x02 \x01(\tR\nnewColName"w\n\x0bWithColumns\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x39\n\x07\x61liases\x18\x02 \x03(\x0b\x32\x1f.spark.connect.Expression.AliasR\x07\x61liases"\x86\x01\n\rWithWatermark\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1d\n\nevent_time\x18\x02 \x01(\tR\teventTime\x12\'\n\x0f\x64\x65lay_threshold\x18\x03 \x01(\tR\x0e\x64\x65layThreshold"\x84\x01\n\x04Hint\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x39\n\nparameters\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\nparameters"\xc7\x02\n\x07Unpivot\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12+\n\x03ids\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x03ids\x12:\n\x06values\x18\x03 \x01(\x0b\x32\x1d.spark.connect.Unpivot.ValuesH\x00R\x06values\x88\x01\x01\x12\x30\n\x14variable_column_name\x18\x04 \x01(\tR\x12variableColumnName\x12*\n\x11value_column_name\x18\x05 \x01(\tR\x0fvalueColumnName\x1a;\n\x06Values\x12\x31\n\x06values\x18\x01 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x06valuesB\t\n\x07_values"j\n\x08ToSchema\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema"\xcb\x01\n\x17RepartitionByExpression\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x42\n\x0fpartition_exprs\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0epartitionExprs\x12*\n\x0enum_partitions\x18\x03 \x01(\x05H\x00R\rnumPartitions\x88\x01\x01\x42\x11\n\x0f_num_partitions"\xe8\x01\n\rMapPartitions\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x42\n\x04\x66unc\x18\x02 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12"\n\nis_barrier\x18\x03 \x01(\x08H\x00R\tisBarrier\x88\x01\x01\x12"\n\nprofile_id\x18\x04 \x01(\x05H\x01R\tprofileId\x88\x01\x01\x42\r\n\x0b_is_barrierB\r\n\x0b_profile_id"\xfb\x04\n\x08GroupMap\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12\x42\n\x04\x66unc\x18\x03 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12J\n\x13sorting_expressions\x18\x04 
\x03(\x0b\x32\x19.spark.connect.ExpressionR\x12sortingExpressions\x12<\n\rinitial_input\x18\x05 \x01(\x0b\x32\x17.spark.connect.RelationR\x0cinitialInput\x12[\n\x1cinitial_grouping_expressions\x18\x06 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x1ainitialGroupingExpressions\x12;\n\x18is_map_groups_with_state\x18\x07 \x01(\x08H\x00R\x14isMapGroupsWithState\x88\x01\x01\x12$\n\x0boutput_mode\x18\x08 \x01(\tH\x01R\noutputMode\x88\x01\x01\x12&\n\x0ctimeout_conf\x18\t \x01(\tH\x02R\x0btimeoutConf\x88\x01\x01\x42\x1b\n\x19_is_map_groups_with_stateB\x0e\n\x0c_output_modeB\x0f\n\r_timeout_conf"\x8e\x04\n\nCoGroupMap\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12W\n\x1ainput_grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x18inputGroupingExpressions\x12-\n\x05other\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x05other\x12W\n\x1aother_grouping_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x18otherGroupingExpressions\x12\x42\n\x04\x66unc\x18\x05 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12U\n\x19input_sorting_expressions\x18\x06 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x17inputSortingExpressions\x12U\n\x19other_sorting_expressions\x18\x07 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x17otherSortingExpressions"\xe5\x02\n\x16\x41pplyInPandasWithState\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12\x42\n\x04\x66unc\x18\x03 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12#\n\routput_schema\x18\x04 \x01(\tR\x0coutputSchema\x12!\n\x0cstate_schema\x18\x05 \x01(\tR\x0bstateSchema\x12\x1f\n\x0boutput_mode\x18\x06 \x01(\tR\noutputMode\x12!\n\x0ctimeout_conf\x18\x07 \x01(\tR\x0btimeoutConf"\xf4\x01\n$CommonInlineUserDefinedTableFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12$\n\rdeterministic\x18\x02 \x01(\x08R\rdeterministic\x12\x37\n\targuments\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12<\n\x0bpython_udtf\x18\x04 \x01(\x0b\x32\x19.spark.connect.PythonUDTFH\x00R\npythonUdtfB\n\n\x08\x66unction"\xb1\x01\n\nPythonUDTF\x12=\n\x0breturn_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\nreturnType\x88\x01\x01\x12\x1b\n\teval_type\x18\x02 \x01(\x05R\x08\x65valType\x12\x18\n\x07\x63ommand\x18\x03 \x01(\x0cR\x07\x63ommand\x12\x1d\n\npython_ver\x18\x04 \x01(\tR\tpythonVerB\x0e\n\x0c_return_type"\x97\x01\n!CommonInlineUserDefinedDataSource\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12O\n\x12python_data_source\x18\x02 \x01(\x0b\x32\x1f.spark.connect.PythonDataSourceH\x00R\x10pythonDataSourceB\r\n\x0b\x64\x61ta_source"K\n\x10PythonDataSource\x12\x18\n\x07\x63ommand\x18\x01 \x01(\x0cR\x07\x63ommand\x12\x1d\n\npython_ver\x18\x02 \x01(\tR\tpythonVer"\x88\x01\n\x0e\x43ollectMetrics\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x33\n\x07metrics\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x07metrics"\x84\x03\n\x05Parse\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x38\n\x06\x66ormat\x18\x02 \x01(\x0e\x32 .spark.connect.Parse.ParseFormatR\x06\x66ormat\x12\x34\n\x06schema\x18\x03 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x06schema\x88\x01\x01\x12;\n\x07options\x18\x04 \x03(\x0b\x32!.spark.connect.Parse.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 
\x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01"X\n\x0bParseFormat\x12\x1c\n\x18PARSE_FORMAT_UNSPECIFIED\x10\x00\x12\x14\n\x10PARSE_FORMAT_CSV\x10\x01\x12\x15\n\x11PARSE_FORMAT_JSON\x10\x02\x42\t\n\x07_schema"\xdb\x03\n\x08\x41sOfJoin\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12\x37\n\nleft_as_of\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x08leftAsOf\x12\x39\n\x0bright_as_of\x18\x04 \x01(\x0b\x32\x19.spark.connect.ExpressionR\trightAsOf\x12\x36\n\tjoin_expr\x18\x05 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x08joinExpr\x12#\n\rusing_columns\x18\x06 \x03(\tR\x0cusingColumns\x12\x1b\n\tjoin_type\x18\x07 \x01(\tR\x08joinType\x12\x37\n\ttolerance\x18\x08 \x01(\x0b\x32\x19.spark.connect.ExpressionR\ttolerance\x12.\n\x13\x61llow_exact_matches\x18\t \x01(\x08R\x11\x61llowExactMatches\x12\x1c\n\tdirection\x18\n \x01(\tR\tdirectionB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3'
+ b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto\x1a\x1aspark/connect/common.proto"\xa3\x1b\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\x32\x13.spark.connect.JoinH\x00R\x04join\x12\x34\n\x06set_op\x18\x06 \x01(\x0b\x32\x1b.spark.connect.SetOperationH\x00R\x05setOp\x12)\n\x04sort\x18\x07 \x01(\x0b\x32\x13.spark.connect.SortH\x00R\x04sort\x12,\n\x05limit\x18\x08 \x01(\x0b\x32\x14.spark.connect.LimitH\x00R\x05limit\x12\x38\n\taggregate\x18\t \x01(\x0b\x32\x18.spark.connect.AggregateH\x00R\taggregate\x12&\n\x03sql\x18\n \x01(\x0b\x32\x12.spark.connect.SQLH\x00R\x03sql\x12\x45\n\x0elocal_relation\x18\x0b \x01(\x0b\x32\x1c.spark.connect.LocalRelationH\x00R\rlocalRelation\x12/\n\x06sample\x18\x0c \x01(\x0b\x32\x15.spark.connect.SampleH\x00R\x06sample\x12/\n\x06offset\x18\r \x01(\x0b\x32\x15.spark.connect.OffsetH\x00R\x06offset\x12>\n\x0b\x64\x65\x64uplicate\x18\x0e \x01(\x0b\x32\x1a.spark.connect.DeduplicateH\x00R\x0b\x64\x65\x64uplicate\x12,\n\x05range\x18\x0f \x01(\x0b\x32\x14.spark.connect.RangeH\x00R\x05range\x12\x45\n\x0esubquery_alias\x18\x10 \x01(\x0b\x32\x1c.spark.connect.SubqueryAliasH\x00R\rsubqueryAlias\x12>\n\x0brepartition\x18\x11 \x01(\x0b\x32\x1a.spark.connect.RepartitionH\x00R\x0brepartition\x12*\n\x05to_df\x18\x12 \x01(\x0b\x32\x13.spark.connect.ToDFH\x00R\x04toDf\x12U\n\x14with_columns_renamed\x18\x13 \x01(\x0b\x32!.spark.connect.WithColumnsRenamedH\x00R\x12withColumnsRenamed\x12<\n\x0bshow_string\x18\x14 \x01(\x0b\x32\x19.spark.connect.ShowStringH\x00R\nshowString\x12)\n\x04\x64rop\x18\x15 \x01(\x0b\x32\x13.spark.connect.DropH\x00R\x04\x64rop\x12)\n\x04tail\x18\x16 \x01(\x0b\x32\x13.spark.connect.TailH\x00R\x04tail\x12?\n\x0cwith_columns\x18\x17 \x01(\x0b\x32\x1a.spark.connect.WithColumnsH\x00R\x0bwithColumns\x12)\n\x04hint\x18\x18 \x01(\x0b\x32\x13.spark.connect.HintH\x00R\x04hint\x12\x32\n\x07unpivot\x18\x19 \x01(\x0b\x32\x16.spark.connect.UnpivotH\x00R\x07unpivot\x12\x36\n\tto_schema\x18\x1a \x01(\x0b\x32\x17.spark.connect.ToSchemaH\x00R\x08toSchema\x12\x64\n\x19repartition_by_expression\x18\x1b \x01(\x0b\x32&.spark.connect.RepartitionByExpressionH\x00R\x17repartitionByExpression\x12\x45\n\x0emap_partitions\x18\x1c \x01(\x0b\x32\x1c.spark.connect.MapPartitionsH\x00R\rmapPartitions\x12H\n\x0f\x63ollect_metrics\x18\x1d \x01(\x0b\x32\x1d.spark.connect.CollectMetricsH\x00R\x0e\x63ollectMetrics\x12,\n\x05parse\x18\x1e \x01(\x0b\x32\x14.spark.connect.ParseH\x00R\x05parse\x12\x36\n\tgroup_map\x18\x1f \x01(\x0b\x32\x17.spark.connect.GroupMapH\x00R\x08groupMap\x12=\n\x0c\x63o_group_map\x18 \x01(\x0b\x32\x19.spark.connect.CoGroupMapH\x00R\ncoGroupMap\x12\x45\n\x0ewith_watermark\x18! 
\x01(\x0b\x32\x1c.spark.connect.WithWatermarkH\x00R\rwithWatermark\x12\x63\n\x1a\x61pply_in_pandas_with_state\x18" \x01(\x0b\x32%.spark.connect.ApplyInPandasWithStateH\x00R\x16\x61pplyInPandasWithState\x12<\n\x0bhtml_string\x18# \x01(\x0b\x32\x19.spark.connect.HtmlStringH\x00R\nhtmlString\x12X\n\x15\x63\x61\x63hed_local_relation\x18$ \x01(\x0b\x32".spark.connect.CachedLocalRelationH\x00R\x13\x63\x61\x63hedLocalRelation\x12[\n\x16\x63\x61\x63hed_remote_relation\x18% \x01(\x0b\x32#.spark.connect.CachedRemoteRelationH\x00R\x14\x63\x61\x63hedRemoteRelation\x12\x8e\x01\n)common_inline_user_defined_table_function\x18& \x01(\x0b\x32\x33.spark.connect.CommonInlineUserDefinedTableFunctionH\x00R$commonInlineUserDefinedTableFunction\x12\x37\n\nas_of_join\x18\' \x01(\x0b\x32\x17.spark.connect.AsOfJoinH\x00R\x08\x61sOfJoin\x12\x85\x01\n&common_inline_user_defined_data_source\x18( \x01(\x0b\x32\x30.spark.connect.CommonInlineUserDefinedDataSourceH\x00R!commonInlineUserDefinedDataSource\x12\x45\n\x0ewith_relations\x18) \x01(\x0b\x32\x1c.spark.connect.WithRelationsH\x00R\rwithRelations\x12\x38\n\ttranspose\x18* \x01(\x0b\x32\x18.spark.connect.TransposeH\x00R\ttranspose\x12\x30\n\x07\x66ill_na\x18Z \x01(\x0b\x32\x15.spark.connect.NAFillH\x00R\x06\x66illNa\x12\x30\n\x07\x64rop_na\x18[ \x01(\x0b\x32\x15.spark.connect.NADropH\x00R\x06\x64ropNa\x12\x34\n\x07replace\x18\\ \x01(\x0b\x32\x18.spark.connect.NAReplaceH\x00R\x07replace\x12\x36\n\x07summary\x18\x64 \x01(\x0b\x32\x1a.spark.connect.StatSummaryH\x00R\x07summary\x12\x39\n\x08\x63rosstab\x18\x65 \x01(\x0b\x32\x1b.spark.connect.StatCrosstabH\x00R\x08\x63rosstab\x12\x39\n\x08\x64\x65scribe\x18\x66 \x01(\x0b\x32\x1b.spark.connect.StatDescribeH\x00R\x08\x64\x65scribe\x12*\n\x03\x63ov\x18g \x01(\x0b\x32\x16.spark.connect.StatCovH\x00R\x03\x63ov\x12-\n\x04\x63orr\x18h \x01(\x0b\x32\x17.spark.connect.StatCorrH\x00R\x04\x63orr\x12L\n\x0f\x61pprox_quantile\x18i \x01(\x0b\x32!.spark.connect.StatApproxQuantileH\x00R\x0e\x61pproxQuantile\x12=\n\nfreq_items\x18j \x01(\x0b\x32\x1c.spark.connect.StatFreqItemsH\x00R\tfreqItems\x12:\n\tsample_by\x18k \x01(\x0b\x32\x1b.spark.connect.StatSampleByH\x00R\x08sampleBy\x12\x33\n\x07\x63\x61talog\x18\xc8\x01 \x01(\x0b\x32\x16.spark.connect.CatalogH\x00R\x07\x63\x61talog\x12\x35\n\textension\x18\xe6\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x33\n\x07unknown\x18\xe7\x07 \x01(\x0b\x32\x16.spark.connect.UnknownH\x00R\x07unknownB\n\n\x08rel_type"\t\n\x07Unknown"\x8e\x01\n\x0eRelationCommon\x12#\n\x0bsource_info\x18\x01 \x01(\tB\x02\x18\x01R\nsourceInfo\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x00R\x06planId\x88\x01\x01\x12-\n\x06origin\x18\x03 \x01(\x0b\x32\x15.spark.connect.OriginR\x06originB\n\n\x08_plan_id"\xde\x03\n\x03SQL\x12\x14\n\x05query\x18\x01 \x01(\tR\x05query\x12\x34\n\x04\x61rgs\x18\x02 \x03(\x0b\x32\x1c.spark.connect.SQL.ArgsEntryB\x02\x18\x01R\x04\x61rgs\x12@\n\x08pos_args\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralB\x02\x18\x01R\x07posArgs\x12O\n\x0fnamed_arguments\x18\x04 \x03(\x0b\x32&.spark.connect.SQL.NamedArgumentsEntryR\x0enamedArguments\x12>\n\rpos_arguments\x18\x05 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0cposArguments\x1aZ\n\tArgsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x37\n\x05value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x05value:\x02\x38\x01\x1a\\\n\x13NamedArgumentsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12/\n\x05value\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05value:\x02\x38\x01"u\n\rWithRelations\x12+\n\x04root\x18\x01 
\x01(\x0b\x32\x17.spark.connect.RelationR\x04root\x12\x37\n\nreferences\x18\x02 \x03(\x0b\x32\x17.spark.connect.RelationR\nreferences"\x97\x05\n\x04Read\x12\x41\n\x0bnamed_table\x18\x01 \x01(\x0b\x32\x1e.spark.connect.Read.NamedTableH\x00R\nnamedTable\x12\x41\n\x0b\x64\x61ta_source\x18\x02 \x01(\x0b\x32\x1e.spark.connect.Read.DataSourceH\x00R\ndataSource\x12!\n\x0cis_streaming\x18\x03 \x01(\x08R\x0bisStreaming\x1a\xc0\x01\n\nNamedTable\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x12\x45\n\x07options\x18\x02 \x03(\x0b\x32+.spark.connect.Read.NamedTable.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x95\x02\n\nDataSource\x12\x1b\n\x06\x66ormat\x18\x01 \x01(\tH\x00R\x06\x66ormat\x88\x01\x01\x12\x1b\n\x06schema\x18\x02 \x01(\tH\x01R\x06schema\x88\x01\x01\x12\x45\n\x07options\x18\x03 \x03(\x0b\x32+.spark.connect.Read.DataSource.OptionsEntryR\x07options\x12\x14\n\x05paths\x18\x04 \x03(\tR\x05paths\x12\x1e\n\npredicates\x18\x05 \x03(\tR\npredicates\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07_formatB\t\n\x07_schemaB\x0b\n\tread_type"u\n\x07Project\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12;\n\x0b\x65xpressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0b\x65xpressions"p\n\x06\x46ilter\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\x95\x05\n\x04Join\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinType\x12#\n\rusing_columns\x18\x05 \x03(\tR\x0cusingColumns\x12K\n\x0ejoin_data_type\x18\x06 \x01(\x0b\x32 .spark.connect.Join.JoinDataTypeH\x00R\x0cjoinDataType\x88\x01\x01\x1a\\\n\x0cJoinDataType\x12$\n\x0eis_left_struct\x18\x01 \x01(\x08R\x0cisLeftStruct\x12&\n\x0fis_right_struct\x18\x02 \x01(\x08R\risRightStruct"\xd0\x01\n\x08JoinType\x12\x19\n\x15JOIN_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOIN_TYPE_INNER\x10\x01\x12\x18\n\x14JOIN_TYPE_FULL_OUTER\x10\x02\x12\x18\n\x14JOIN_TYPE_LEFT_OUTER\x10\x03\x12\x19\n\x15JOIN_TYPE_RIGHT_OUTER\x10\x04\x12\x17\n\x13JOIN_TYPE_LEFT_ANTI\x10\x05\x12\x17\n\x13JOIN_TYPE_LEFT_SEMI\x10\x06\x12\x13\n\x0fJOIN_TYPE_CROSS\x10\x07\x42\x11\n\x0f_join_data_type"\xdf\x03\n\x0cSetOperation\x12\x36\n\nleft_input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\tleftInput\x12\x38\n\x0bright_input\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\nrightInput\x12\x45\n\x0bset_op_type\x18\x03 \x01(\x0e\x32%.spark.connect.SetOperation.SetOpTypeR\tsetOpType\x12\x1a\n\x06is_all\x18\x04 \x01(\x08H\x00R\x05isAll\x88\x01\x01\x12\x1c\n\x07\x62y_name\x18\x05 \x01(\x08H\x01R\x06\x62yName\x88\x01\x01\x12\x37\n\x15\x61llow_missing_columns\x18\x06 \x01(\x08H\x02R\x13\x61llowMissingColumns\x88\x01\x01"r\n\tSetOpType\x12\x1b\n\x17SET_OP_TYPE_UNSPECIFIED\x10\x00\x12\x19\n\x15SET_OP_TYPE_INTERSECT\x10\x01\x12\x15\n\x11SET_OP_TYPE_UNION\x10\x02\x12\x16\n\x12SET_OP_TYPE_EXCEPT\x10\x03\x42\t\n\x07_is_allB\n\n\x08_by_nameB\x18\n\x16_allow_missing_columns"L\n\x05Limit\x12-\n\x05input\x18\x01 
\x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"O\n\x06Offset\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06offset\x18\x02 \x01(\x05R\x06offset"K\n\x04Tail\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"\xfe\x05\n\tAggregate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x41\n\ngroup_type\x18\x02 \x01(\x0e\x32".spark.connect.Aggregate.GroupTypeR\tgroupType\x12L\n\x14grouping_expressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12N\n\x15\x61ggregate_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x14\x61ggregateExpressions\x12\x34\n\x05pivot\x18\x05 \x01(\x0b\x32\x1e.spark.connect.Aggregate.PivotR\x05pivot\x12J\n\rgrouping_sets\x18\x06 \x03(\x0b\x32%.spark.connect.Aggregate.GroupingSetsR\x0cgroupingSets\x1ao\n\x05Pivot\x12+\n\x03\x63ol\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x03\x63ol\x12\x39\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x1aL\n\x0cGroupingSets\x12<\n\x0cgrouping_set\x18\x01 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0bgroupingSet"\x9f\x01\n\tGroupType\x12\x1a\n\x16GROUP_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12GROUP_TYPE_GROUPBY\x10\x01\x12\x15\n\x11GROUP_TYPE_ROLLUP\x10\x02\x12\x13\n\x0fGROUP_TYPE_CUBE\x10\x03\x12\x14\n\x10GROUP_TYPE_PIVOT\x10\x04\x12\x1c\n\x18GROUP_TYPE_GROUPING_SETS\x10\x05"\xa0\x01\n\x04Sort\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x39\n\x05order\x18\x02 \x03(\x0b\x32#.spark.connect.Expression.SortOrderR\x05order\x12 \n\tis_global\x18\x03 \x01(\x08H\x00R\x08isGlobal\x88\x01\x01\x42\x0c\n\n_is_global"\x8d\x01\n\x04\x44rop\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x33\n\x07\x63olumns\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x07\x63olumns\x12!\n\x0c\x63olumn_names\x18\x03 \x03(\tR\x0b\x63olumnNames"\xf0\x01\n\x0b\x44\x65\x64uplicate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames\x12\x32\n\x13\x61ll_columns_as_keys\x18\x03 \x01(\x08H\x00R\x10\x61llColumnsAsKeys\x88\x01\x01\x12.\n\x10within_watermark\x18\x04 \x01(\x08H\x01R\x0fwithinWatermark\x88\x01\x01\x42\x16\n\x14_all_columns_as_keysB\x13\n\x11_within_watermark"Y\n\rLocalRelation\x12\x17\n\x04\x64\x61ta\x18\x01 \x01(\x0cH\x00R\x04\x64\x61ta\x88\x01\x01\x12\x1b\n\x06schema\x18\x02 \x01(\tH\x01R\x06schema\x88\x01\x01\x42\x07\n\x05_dataB\t\n\x07_schema"H\n\x13\x43\x61\x63hedLocalRelation\x12\x12\n\x04hash\x18\x03 \x01(\tR\x04hashJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03R\x06userIdR\tsessionId"7\n\x14\x43\x61\x63hedRemoteRelation\x12\x1f\n\x0brelation_id\x18\x01 \x01(\tR\nrelationId"\x91\x02\n\x06Sample\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1f\n\x0blower_bound\x18\x02 \x01(\x01R\nlowerBound\x12\x1f\n\x0bupper_bound\x18\x03 \x01(\x01R\nupperBound\x12.\n\x10with_replacement\x18\x04 \x01(\x08H\x00R\x0fwithReplacement\x88\x01\x01\x12\x17\n\x04seed\x18\x05 \x01(\x03H\x01R\x04seed\x88\x01\x01\x12/\n\x13\x64\x65terministic_order\x18\x06 \x01(\x08R\x12\x64\x65terministicOrderB\x13\n\x11_with_replacementB\x07\n\x05_seed"\x91\x01\n\x05Range\x12\x19\n\x05start\x18\x01 \x01(\x03H\x00R\x05start\x88\x01\x01\x12\x10\n\x03\x65nd\x18\x02 \x01(\x03R\x03\x65nd\x12\x12\n\x04step\x18\x03 
\x01(\x03R\x04step\x12*\n\x0enum_partitions\x18\x04 \x01(\x05H\x01R\rnumPartitions\x88\x01\x01\x42\x08\n\x06_startB\x11\n\x0f_num_partitions"r\n\rSubqueryAlias\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias\x12\x1c\n\tqualifier\x18\x03 \x03(\tR\tqualifier"\x8e\x01\n\x0bRepartition\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12%\n\x0enum_partitions\x18\x02 \x01(\x05R\rnumPartitions\x12\x1d\n\x07shuffle\x18\x03 \x01(\x08H\x00R\x07shuffle\x88\x01\x01\x42\n\n\x08_shuffle"\x8e\x01\n\nShowString\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x19\n\x08num_rows\x18\x02 \x01(\x05R\x07numRows\x12\x1a\n\x08truncate\x18\x03 \x01(\x05R\x08truncate\x12\x1a\n\x08vertical\x18\x04 \x01(\x08R\x08vertical"r\n\nHtmlString\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x19\n\x08num_rows\x18\x02 \x01(\x05R\x07numRows\x12\x1a\n\x08truncate\x18\x03 \x01(\x05R\x08truncate"\\\n\x0bStatSummary\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1e\n\nstatistics\x18\x02 \x03(\tR\nstatistics"Q\n\x0cStatDescribe\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols"e\n\x0cStatCrosstab\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2"`\n\x07StatCov\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2"\x89\x01\n\x08StatCorr\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2\x12\x1b\n\x06method\x18\x04 \x01(\tH\x00R\x06method\x88\x01\x01\x42\t\n\x07_method"\xa4\x01\n\x12StatApproxQuantile\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12$\n\rprobabilities\x18\x03 \x03(\x01R\rprobabilities\x12%\n\x0erelative_error\x18\x04 \x01(\x01R\rrelativeError"}\n\rStatFreqItems\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\x1d\n\x07support\x18\x03 \x01(\x01H\x00R\x07support\x88\x01\x01\x42\n\n\x08_support"\xb5\x02\n\x0cStatSampleBy\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12+\n\x03\x63ol\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x03\x63ol\x12\x42\n\tfractions\x18\x03 \x03(\x0b\x32$.spark.connect.StatSampleBy.FractionR\tfractions\x12\x17\n\x04seed\x18\x05 \x01(\x03H\x00R\x04seed\x88\x01\x01\x1a\x63\n\x08\x46raction\x12;\n\x07stratum\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x07stratum\x12\x1a\n\x08\x66raction\x18\x02 \x01(\x01R\x08\x66ractionB\x07\n\x05_seed"\x86\x01\n\x06NAFill\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\x39\n\x06values\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values"\x86\x01\n\x06NADrop\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\'\n\rmin_non_nulls\x18\x03 \x01(\x05H\x00R\x0bminNonNulls\x88\x01\x01\x42\x10\n\x0e_min_non_nulls"\xa8\x02\n\tNAReplace\x12-\n\x05input\x18\x01 
\x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12H\n\x0creplacements\x18\x03 \x03(\x0b\x32$.spark.connect.NAReplace.ReplacementR\x0creplacements\x1a\x8d\x01\n\x0bReplacement\x12>\n\told_value\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x08oldValue\x12>\n\tnew_value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x08newValue"X\n\x04ToDF\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames"\xfe\x02\n\x12WithColumnsRenamed\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12i\n\x12rename_columns_map\x18\x02 \x03(\x0b\x32\x37.spark.connect.WithColumnsRenamed.RenameColumnsMapEntryB\x02\x18\x01R\x10renameColumnsMap\x12\x42\n\x07renames\x18\x03 \x03(\x0b\x32(.spark.connect.WithColumnsRenamed.RenameR\x07renames\x1a\x43\n\x15RenameColumnsMapEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x45\n\x06Rename\x12\x19\n\x08\x63ol_name\x18\x01 \x01(\tR\x07\x63olName\x12 \n\x0cnew_col_name\x18\x02 \x01(\tR\nnewColName"w\n\x0bWithColumns\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x39\n\x07\x61liases\x18\x02 \x03(\x0b\x32\x1f.spark.connect.Expression.AliasR\x07\x61liases"\x86\x01\n\rWithWatermark\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1d\n\nevent_time\x18\x02 \x01(\tR\teventTime\x12\'\n\x0f\x64\x65lay_threshold\x18\x03 \x01(\tR\x0e\x64\x65layThreshold"\x84\x01\n\x04Hint\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x39\n\nparameters\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\nparameters"\xc7\x02\n\x07Unpivot\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12+\n\x03ids\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x03ids\x12:\n\x06values\x18\x03 \x01(\x0b\x32\x1d.spark.connect.Unpivot.ValuesH\x00R\x06values\x88\x01\x01\x12\x30\n\x14variable_column_name\x18\x04 \x01(\tR\x12variableColumnName\x12*\n\x11value_column_name\x18\x05 \x01(\tR\x0fvalueColumnName\x1a;\n\x06Values\x12\x31\n\x06values\x18\x01 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x06valuesB\t\n\x07_values"z\n\tTranspose\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12>\n\rindex_columns\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0cindexColumns"j\n\x08ToSchema\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema"\xcb\x01\n\x17RepartitionByExpression\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x42\n\x0fpartition_exprs\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0epartitionExprs\x12*\n\x0enum_partitions\x18\x03 \x01(\x05H\x00R\rnumPartitions\x88\x01\x01\x42\x11\n\x0f_num_partitions"\xe8\x01\n\rMapPartitions\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x42\n\x04\x66unc\x18\x02 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12"\n\nis_barrier\x18\x03 \x01(\x08H\x00R\tisBarrier\x88\x01\x01\x12"\n\nprofile_id\x18\x04 \x01(\x05H\x01R\tprofileId\x88\x01\x01\x42\r\n\x0b_is_barrierB\r\n\x0b_profile_id"\xfb\x04\n\x08GroupMap\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 
\x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12\x42\n\x04\x66unc\x18\x03 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12J\n\x13sorting_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x12sortingExpressions\x12<\n\rinitial_input\x18\x05 \x01(\x0b\x32\x17.spark.connect.RelationR\x0cinitialInput\x12[\n\x1cinitial_grouping_expressions\x18\x06 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x1ainitialGroupingExpressions\x12;\n\x18is_map_groups_with_state\x18\x07 \x01(\x08H\x00R\x14isMapGroupsWithState\x88\x01\x01\x12$\n\x0boutput_mode\x18\x08 \x01(\tH\x01R\noutputMode\x88\x01\x01\x12&\n\x0ctimeout_conf\x18\t \x01(\tH\x02R\x0btimeoutConf\x88\x01\x01\x42\x1b\n\x19_is_map_groups_with_stateB\x0e\n\x0c_output_modeB\x0f\n\r_timeout_conf"\x8e\x04\n\nCoGroupMap\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12W\n\x1ainput_grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x18inputGroupingExpressions\x12-\n\x05other\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x05other\x12W\n\x1aother_grouping_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x18otherGroupingExpressions\x12\x42\n\x04\x66unc\x18\x05 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12U\n\x19input_sorting_expressions\x18\x06 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x17inputSortingExpressions\x12U\n\x19other_sorting_expressions\x18\x07 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x17otherSortingExpressions"\xe5\x02\n\x16\x41pplyInPandasWithState\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12\x42\n\x04\x66unc\x18\x03 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12#\n\routput_schema\x18\x04 \x01(\tR\x0coutputSchema\x12!\n\x0cstate_schema\x18\x05 \x01(\tR\x0bstateSchema\x12\x1f\n\x0boutput_mode\x18\x06 \x01(\tR\noutputMode\x12!\n\x0ctimeout_conf\x18\x07 \x01(\tR\x0btimeoutConf"\xf4\x01\n$CommonInlineUserDefinedTableFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12$\n\rdeterministic\x18\x02 \x01(\x08R\rdeterministic\x12\x37\n\targuments\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12<\n\x0bpython_udtf\x18\x04 \x01(\x0b\x32\x19.spark.connect.PythonUDTFH\x00R\npythonUdtfB\n\n\x08\x66unction"\xb1\x01\n\nPythonUDTF\x12=\n\x0breturn_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\nreturnType\x88\x01\x01\x12\x1b\n\teval_type\x18\x02 \x01(\x05R\x08\x65valType\x12\x18\n\x07\x63ommand\x18\x03 \x01(\x0cR\x07\x63ommand\x12\x1d\n\npython_ver\x18\x04 \x01(\tR\tpythonVerB\x0e\n\x0c_return_type"\x97\x01\n!CommonInlineUserDefinedDataSource\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12O\n\x12python_data_source\x18\x02 \x01(\x0b\x32\x1f.spark.connect.PythonDataSourceH\x00R\x10pythonDataSourceB\r\n\x0b\x64\x61ta_source"K\n\x10PythonDataSource\x12\x18\n\x07\x63ommand\x18\x01 \x01(\x0cR\x07\x63ommand\x12\x1d\n\npython_ver\x18\x02 \x01(\tR\tpythonVer"\x88\x01\n\x0e\x43ollectMetrics\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x33\n\x07metrics\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x07metrics"\x84\x03\n\x05Parse\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x38\n\x06\x66ormat\x18\x02 \x01(\x0e\x32 .spark.connect.Parse.ParseFormatR\x06\x66ormat\x12\x34\n\x06schema\x18\x03 
\x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x06schema\x88\x01\x01\x12;\n\x07options\x18\x04 \x03(\x0b\x32!.spark.connect.Parse.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01"X\n\x0bParseFormat\x12\x1c\n\x18PARSE_FORMAT_UNSPECIFIED\x10\x00\x12\x14\n\x10PARSE_FORMAT_CSV\x10\x01\x12\x15\n\x11PARSE_FORMAT_JSON\x10\x02\x42\t\n\x07_schema"\xdb\x03\n\x08\x41sOfJoin\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12\x37\n\nleft_as_of\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x08leftAsOf\x12\x39\n\x0bright_as_of\x18\x04 \x01(\x0b\x32\x19.spark.connect.ExpressionR\trightAsOf\x12\x36\n\tjoin_expr\x18\x05 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x08joinExpr\x12#\n\rusing_columns\x18\x06 \x03(\tR\x0cusingColumns\x12\x1b\n\tjoin_type\x18\x07 \x01(\tR\x08joinType\x12\x37\n\ttolerance\x18\x08 \x01(\x0b\x32\x19.spark.connect.ExpressionR\ttolerance\x12.\n\x13\x61llow_exact_matches\x18\t \x01(\x08R\x11\x61llowExactMatches\x12\x1c\n\tdirection\x18\n \x01(\tR\tdirectionB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3'
)
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -69,153 +69,155 @@
_PARSE_OPTIONSENTRY._options = None
_PARSE_OPTIONSENTRY._serialized_options = b"8\001"
_RELATION._serialized_start = 193
- _RELATION._serialized_end = 3626
- _UNKNOWN._serialized_start = 3628
- _UNKNOWN._serialized_end = 3637
- _RELATIONCOMMON._serialized_start = 3640
- _RELATIONCOMMON._serialized_end = 3782
- _SQL._serialized_start = 3785
- _SQL._serialized_end = 4263
- _SQL_ARGSENTRY._serialized_start = 4079
- _SQL_ARGSENTRY._serialized_end = 4169
- _SQL_NAMEDARGUMENTSENTRY._serialized_start = 4171
- _SQL_NAMEDARGUMENTSENTRY._serialized_end = 4263
- _WITHRELATIONS._serialized_start = 4265
- _WITHRELATIONS._serialized_end = 4382
- _READ._serialized_start = 4385
- _READ._serialized_end = 5048
- _READ_NAMEDTABLE._serialized_start = 4563
- _READ_NAMEDTABLE._serialized_end = 4755
- _READ_NAMEDTABLE_OPTIONSENTRY._serialized_start = 4697
- _READ_NAMEDTABLE_OPTIONSENTRY._serialized_end = 4755
- _READ_DATASOURCE._serialized_start = 4758
- _READ_DATASOURCE._serialized_end = 5035
- _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 4697
- _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 4755
- _PROJECT._serialized_start = 5050
- _PROJECT._serialized_end = 5167
- _FILTER._serialized_start = 5169
- _FILTER._serialized_end = 5281
- _JOIN._serialized_start = 5284
- _JOIN._serialized_end = 5945
- _JOIN_JOINDATATYPE._serialized_start = 5623
- _JOIN_JOINDATATYPE._serialized_end = 5715
- _JOIN_JOINTYPE._serialized_start = 5718
- _JOIN_JOINTYPE._serialized_end = 5926
- _SETOPERATION._serialized_start = 5948
- _SETOPERATION._serialized_end = 6427
- _SETOPERATION_SETOPTYPE._serialized_start = 6264
- _SETOPERATION_SETOPTYPE._serialized_end = 6378
- _LIMIT._serialized_start = 6429
- _LIMIT._serialized_end = 6505
- _OFFSET._serialized_start = 6507
- _OFFSET._serialized_end = 6586
- _TAIL._serialized_start = 6588
- _TAIL._serialized_end = 6663
- _AGGREGATE._serialized_start = 6666
- _AGGREGATE._serialized_end = 7432
- _AGGREGATE_PIVOT._serialized_start = 7081
- _AGGREGATE_PIVOT._serialized_end = 7192
- _AGGREGATE_GROUPINGSETS._serialized_start = 7194
- _AGGREGATE_GROUPINGSETS._serialized_end = 7270
- _AGGREGATE_GROUPTYPE._serialized_start = 7273
- _AGGREGATE_GROUPTYPE._serialized_end = 7432
- _SORT._serialized_start = 7435
- _SORT._serialized_end = 7595
- _DROP._serialized_start = 7598
- _DROP._serialized_end = 7739
- _DEDUPLICATE._serialized_start = 7742
- _DEDUPLICATE._serialized_end = 7982
- _LOCALRELATION._serialized_start = 7984
- _LOCALRELATION._serialized_end = 8073
- _CACHEDLOCALRELATION._serialized_start = 8075
- _CACHEDLOCALRELATION._serialized_end = 8147
- _CACHEDREMOTERELATION._serialized_start = 8149
- _CACHEDREMOTERELATION._serialized_end = 8204
- _SAMPLE._serialized_start = 8207
- _SAMPLE._serialized_end = 8480
- _RANGE._serialized_start = 8483
- _RANGE._serialized_end = 8628
- _SUBQUERYALIAS._serialized_start = 8630
- _SUBQUERYALIAS._serialized_end = 8744
- _REPARTITION._serialized_start = 8747
- _REPARTITION._serialized_end = 8889
- _SHOWSTRING._serialized_start = 8892
- _SHOWSTRING._serialized_end = 9034
- _HTMLSTRING._serialized_start = 9036
- _HTMLSTRING._serialized_end = 9150
- _STATSUMMARY._serialized_start = 9152
- _STATSUMMARY._serialized_end = 9244
- _STATDESCRIBE._serialized_start = 9246
- _STATDESCRIBE._serialized_end = 9327
- _STATCROSSTAB._serialized_start = 9329
- _STATCROSSTAB._serialized_end = 9430
- _STATCOV._serialized_start = 9432
- _STATCOV._serialized_end = 9528
- _STATCORR._serialized_start = 9531
- _STATCORR._serialized_end = 9668
- _STATAPPROXQUANTILE._serialized_start = 9671
- _STATAPPROXQUANTILE._serialized_end = 9835
- _STATFREQITEMS._serialized_start = 9837
- _STATFREQITEMS._serialized_end = 9962
- _STATSAMPLEBY._serialized_start = 9965
- _STATSAMPLEBY._serialized_end = 10274
- _STATSAMPLEBY_FRACTION._serialized_start = 10166
- _STATSAMPLEBY_FRACTION._serialized_end = 10265
- _NAFILL._serialized_start = 10277
- _NAFILL._serialized_end = 10411
- _NADROP._serialized_start = 10414
- _NADROP._serialized_end = 10548
- _NAREPLACE._serialized_start = 10551
- _NAREPLACE._serialized_end = 10847
- _NAREPLACE_REPLACEMENT._serialized_start = 10706
- _NAREPLACE_REPLACEMENT._serialized_end = 10847
- _TODF._serialized_start = 10849
- _TODF._serialized_end = 10937
- _WITHCOLUMNSRENAMED._serialized_start = 10940
- _WITHCOLUMNSRENAMED._serialized_end = 11322
- _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 11184
- _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 11251
- _WITHCOLUMNSRENAMED_RENAME._serialized_start = 11253
- _WITHCOLUMNSRENAMED_RENAME._serialized_end = 11322
- _WITHCOLUMNS._serialized_start = 11324
- _WITHCOLUMNS._serialized_end = 11443
- _WITHWATERMARK._serialized_start = 11446
- _WITHWATERMARK._serialized_end = 11580
- _HINT._serialized_start = 11583
- _HINT._serialized_end = 11715
- _UNPIVOT._serialized_start = 11718
- _UNPIVOT._serialized_end = 12045
- _UNPIVOT_VALUES._serialized_start = 11975
- _UNPIVOT_VALUES._serialized_end = 12034
- _TOSCHEMA._serialized_start = 12047
- _TOSCHEMA._serialized_end = 12153
- _REPARTITIONBYEXPRESSION._serialized_start = 12156
- _REPARTITIONBYEXPRESSION._serialized_end = 12359
- _MAPPARTITIONS._serialized_start = 12362
- _MAPPARTITIONS._serialized_end = 12594
- _GROUPMAP._serialized_start = 12597
- _GROUPMAP._serialized_end = 13232
- _COGROUPMAP._serialized_start = 13235
- _COGROUPMAP._serialized_end = 13761
- _APPLYINPANDASWITHSTATE._serialized_start = 13764
- _APPLYINPANDASWITHSTATE._serialized_end = 14121
- _COMMONINLINEUSERDEFINEDTABLEFUNCTION._serialized_start = 14124
- _COMMONINLINEUSERDEFINEDTABLEFUNCTION._serialized_end = 14368
- _PYTHONUDTF._serialized_start = 14371
- _PYTHONUDTF._serialized_end = 14548
- _COMMONINLINEUSERDEFINEDDATASOURCE._serialized_start = 14551
- _COMMONINLINEUSERDEFINEDDATASOURCE._serialized_end = 14702
- _PYTHONDATASOURCE._serialized_start = 14704
- _PYTHONDATASOURCE._serialized_end = 14779
- _COLLECTMETRICS._serialized_start = 14782
- _COLLECTMETRICS._serialized_end = 14918
- _PARSE._serialized_start = 14921
- _PARSE._serialized_end = 15309
- _PARSE_OPTIONSENTRY._serialized_start = 4697
- _PARSE_OPTIONSENTRY._serialized_end = 4755
- _PARSE_PARSEFORMAT._serialized_start = 15210
- _PARSE_PARSEFORMAT._serialized_end = 15298
- _ASOFJOIN._serialized_start = 15312
- _ASOFJOIN._serialized_end = 15787
+ _RELATION._serialized_end = 3684
+ _UNKNOWN._serialized_start = 3686
+ _UNKNOWN._serialized_end = 3695
+ _RELATIONCOMMON._serialized_start = 3698
+ _RELATIONCOMMON._serialized_end = 3840
+ _SQL._serialized_start = 3843
+ _SQL._serialized_end = 4321
+ _SQL_ARGSENTRY._serialized_start = 4137
+ _SQL_ARGSENTRY._serialized_end = 4227
+ _SQL_NAMEDARGUMENTSENTRY._serialized_start = 4229
+ _SQL_NAMEDARGUMENTSENTRY._serialized_end = 4321
+ _WITHRELATIONS._serialized_start = 4323
+ _WITHRELATIONS._serialized_end = 4440
+ _READ._serialized_start = 4443
+ _READ._serialized_end = 5106
+ _READ_NAMEDTABLE._serialized_start = 4621
+ _READ_NAMEDTABLE._serialized_end = 4813
+ _READ_NAMEDTABLE_OPTIONSENTRY._serialized_start = 4755
+ _READ_NAMEDTABLE_OPTIONSENTRY._serialized_end = 4813
+ _READ_DATASOURCE._serialized_start = 4816
+ _READ_DATASOURCE._serialized_end = 5093
+ _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 4755
+ _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 4813
+ _PROJECT._serialized_start = 5108
+ _PROJECT._serialized_end = 5225
+ _FILTER._serialized_start = 5227
+ _FILTER._serialized_end = 5339
+ _JOIN._serialized_start = 5342
+ _JOIN._serialized_end = 6003
+ _JOIN_JOINDATATYPE._serialized_start = 5681
+ _JOIN_JOINDATATYPE._serialized_end = 5773
+ _JOIN_JOINTYPE._serialized_start = 5776
+ _JOIN_JOINTYPE._serialized_end = 5984
+ _SETOPERATION._serialized_start = 6006
+ _SETOPERATION._serialized_end = 6485
+ _SETOPERATION_SETOPTYPE._serialized_start = 6322
+ _SETOPERATION_SETOPTYPE._serialized_end = 6436
+ _LIMIT._serialized_start = 6487
+ _LIMIT._serialized_end = 6563
+ _OFFSET._serialized_start = 6565
+ _OFFSET._serialized_end = 6644
+ _TAIL._serialized_start = 6646
+ _TAIL._serialized_end = 6721
+ _AGGREGATE._serialized_start = 6724
+ _AGGREGATE._serialized_end = 7490
+ _AGGREGATE_PIVOT._serialized_start = 7139
+ _AGGREGATE_PIVOT._serialized_end = 7250
+ _AGGREGATE_GROUPINGSETS._serialized_start = 7252
+ _AGGREGATE_GROUPINGSETS._serialized_end = 7328
+ _AGGREGATE_GROUPTYPE._serialized_start = 7331
+ _AGGREGATE_GROUPTYPE._serialized_end = 7490
+ _SORT._serialized_start = 7493
+ _SORT._serialized_end = 7653
+ _DROP._serialized_start = 7656
+ _DROP._serialized_end = 7797
+ _DEDUPLICATE._serialized_start = 7800
+ _DEDUPLICATE._serialized_end = 8040
+ _LOCALRELATION._serialized_start = 8042
+ _LOCALRELATION._serialized_end = 8131
+ _CACHEDLOCALRELATION._serialized_start = 8133
+ _CACHEDLOCALRELATION._serialized_end = 8205
+ _CACHEDREMOTERELATION._serialized_start = 8207
+ _CACHEDREMOTERELATION._serialized_end = 8262
+ _SAMPLE._serialized_start = 8265
+ _SAMPLE._serialized_end = 8538
+ _RANGE._serialized_start = 8541
+ _RANGE._serialized_end = 8686
+ _SUBQUERYALIAS._serialized_start = 8688
+ _SUBQUERYALIAS._serialized_end = 8802
+ _REPARTITION._serialized_start = 8805
+ _REPARTITION._serialized_end = 8947
+ _SHOWSTRING._serialized_start = 8950
+ _SHOWSTRING._serialized_end = 9092
+ _HTMLSTRING._serialized_start = 9094
+ _HTMLSTRING._serialized_end = 9208
+ _STATSUMMARY._serialized_start = 9210
+ _STATSUMMARY._serialized_end = 9302
+ _STATDESCRIBE._serialized_start = 9304
+ _STATDESCRIBE._serialized_end = 9385
+ _STATCROSSTAB._serialized_start = 9387
+ _STATCROSSTAB._serialized_end = 9488
+ _STATCOV._serialized_start = 9490
+ _STATCOV._serialized_end = 9586
+ _STATCORR._serialized_start = 9589
+ _STATCORR._serialized_end = 9726
+ _STATAPPROXQUANTILE._serialized_start = 9729
+ _STATAPPROXQUANTILE._serialized_end = 9893
+ _STATFREQITEMS._serialized_start = 9895
+ _STATFREQITEMS._serialized_end = 10020
+ _STATSAMPLEBY._serialized_start = 10023
+ _STATSAMPLEBY._serialized_end = 10332
+ _STATSAMPLEBY_FRACTION._serialized_start = 10224
+ _STATSAMPLEBY_FRACTION._serialized_end = 10323
+ _NAFILL._serialized_start = 10335
+ _NAFILL._serialized_end = 10469
+ _NADROP._serialized_start = 10472
+ _NADROP._serialized_end = 10606
+ _NAREPLACE._serialized_start = 10609
+ _NAREPLACE._serialized_end = 10905
+ _NAREPLACE_REPLACEMENT._serialized_start = 10764
+ _NAREPLACE_REPLACEMENT._serialized_end = 10905
+ _TODF._serialized_start = 10907
+ _TODF._serialized_end = 10995
+ _WITHCOLUMNSRENAMED._serialized_start = 10998
+ _WITHCOLUMNSRENAMED._serialized_end = 11380
+ _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 11242
+ _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 11309
+ _WITHCOLUMNSRENAMED_RENAME._serialized_start = 11311
+ _WITHCOLUMNSRENAMED_RENAME._serialized_end = 11380
+ _WITHCOLUMNS._serialized_start = 11382
+ _WITHCOLUMNS._serialized_end = 11501
+ _WITHWATERMARK._serialized_start = 11504
+ _WITHWATERMARK._serialized_end = 11638
+ _HINT._serialized_start = 11641
+ _HINT._serialized_end = 11773
+ _UNPIVOT._serialized_start = 11776
+ _UNPIVOT._serialized_end = 12103
+ _UNPIVOT_VALUES._serialized_start = 12033
+ _UNPIVOT_VALUES._serialized_end = 12092
+ _TRANSPOSE._serialized_start = 12105
+ _TRANSPOSE._serialized_end = 12227
+ _TOSCHEMA._serialized_start = 12229
+ _TOSCHEMA._serialized_end = 12335
+ _REPARTITIONBYEXPRESSION._serialized_start = 12338
+ _REPARTITIONBYEXPRESSION._serialized_end = 12541
+ _MAPPARTITIONS._serialized_start = 12544
+ _MAPPARTITIONS._serialized_end = 12776
+ _GROUPMAP._serialized_start = 12779
+ _GROUPMAP._serialized_end = 13414
+ _COGROUPMAP._serialized_start = 13417
+ _COGROUPMAP._serialized_end = 13943
+ _APPLYINPANDASWITHSTATE._serialized_start = 13946
+ _APPLYINPANDASWITHSTATE._serialized_end = 14303
+ _COMMONINLINEUSERDEFINEDTABLEFUNCTION._serialized_start = 14306
+ _COMMONINLINEUSERDEFINEDTABLEFUNCTION._serialized_end = 14550
+ _PYTHONUDTF._serialized_start = 14553
+ _PYTHONUDTF._serialized_end = 14730
+ _COMMONINLINEUSERDEFINEDDATASOURCE._serialized_start = 14733
+ _COMMONINLINEUSERDEFINEDDATASOURCE._serialized_end = 14884
+ _PYTHONDATASOURCE._serialized_start = 14886
+ _PYTHONDATASOURCE._serialized_end = 14961
+ _COLLECTMETRICS._serialized_start = 14964
+ _COLLECTMETRICS._serialized_end = 15100
+ _PARSE._serialized_start = 15103
+ _PARSE._serialized_end = 15491
+ _PARSE_OPTIONSENTRY._serialized_start = 4755
+ _PARSE_OPTIONSENTRY._serialized_end = 4813
+ _PARSE_PARSEFORMAT._serialized_start = 15392
+ _PARSE_PARSEFORMAT._serialized_end = 15480
+ _ASOFJOIN._serialized_start = 15494
+ _ASOFJOIN._serialized_end = 15969
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 864803fd33084..b1cd2e184d085 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -104,6 +104,7 @@ class Relation(google.protobuf.message.Message):
AS_OF_JOIN_FIELD_NUMBER: builtins.int
COMMON_INLINE_USER_DEFINED_DATA_SOURCE_FIELD_NUMBER: builtins.int
WITH_RELATIONS_FIELD_NUMBER: builtins.int
+ TRANSPOSE_FIELD_NUMBER: builtins.int
FILL_NA_FIELD_NUMBER: builtins.int
DROP_NA_FIELD_NUMBER: builtins.int
REPLACE_FIELD_NUMBER: builtins.int
@@ -205,6 +206,8 @@ class Relation(google.protobuf.message.Message):
@property
def with_relations(self) -> global___WithRelations: ...
@property
+ def transpose(self) -> global___Transpose: ...
+ @property
def fill_na(self) -> global___NAFill:
"""NA functions"""
@property
@@ -284,6 +287,7 @@ class Relation(google.protobuf.message.Message):
common_inline_user_defined_data_source: global___CommonInlineUserDefinedDataSource
| None = ...,
with_relations: global___WithRelations | None = ...,
+ transpose: global___Transpose | None = ...,
fill_na: global___NAFill | None = ...,
drop_na: global___NADrop | None = ...,
replace: global___NAReplace | None = ...,
@@ -402,6 +406,8 @@ class Relation(google.protobuf.message.Message):
b"to_df",
"to_schema",
b"to_schema",
+ "transpose",
+ b"transpose",
"unknown",
b"unknown",
"unpivot",
@@ -519,6 +525,8 @@ class Relation(google.protobuf.message.Message):
b"to_df",
"to_schema",
b"to_schema",
+ "transpose",
+ b"transpose",
"unknown",
b"unknown",
"unpivot",
@@ -577,6 +585,7 @@ class Relation(google.protobuf.message.Message):
"as_of_join",
"common_inline_user_defined_data_source",
"with_relations",
+ "transpose",
"fill_na",
"drop_na",
"replace",
@@ -3141,6 +3150,47 @@ class Unpivot(google.protobuf.message.Message):
global___Unpivot = Unpivot
+class Transpose(google.protobuf.message.Message):
+ """Transpose a DataFrame, switching rows to columns.
+ Transforms the DataFrame such that the values in the specified index column
+ become the new columns of the DataFrame.
+ """
+
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ INPUT_FIELD_NUMBER: builtins.int
+ INDEX_COLUMNS_FIELD_NUMBER: builtins.int
+ @property
+ def input(self) -> global___Relation:
+ """(Required) The input relation."""
+ @property
+ def index_columns(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+ pyspark.sql.connect.proto.expressions_pb2.Expression
+ ]:
+ """(Optional) A list of columns that will be treated as the indices.
+ Only a single column is supported now.
+ """
+ def __init__(
+ self,
+ *,
+ input: global___Relation | None = ...,
+ index_columns: collections.abc.Iterable[
+ pyspark.sql.connect.proto.expressions_pb2.Expression
+ ]
+ | None = ...,
+ ) -> None: ...
+ def HasField(
+ self, field_name: typing_extensions.Literal["input", b"input"]
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal["index_columns", b"index_columns", "input", b"input"],
+ ) -> None: ...
+
+global___Transpose = Transpose
+
class ToSchema(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
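For illustration, a minimal sketch (not part of the patch) of how the Transpose message stubbed above can be assembled from the generated classes. The use of Expression.UnresolvedAttribute and the column name "id" are assumptions made for the example, not requirements of the API.

from pyspark.sql.connect.proto import expressions_pb2, relations_pb2

# Build a single index column; only one index column is supported.
index_col = expressions_pb2.Expression(
    unresolved_attribute=expressions_pb2.Expression.UnresolvedAttribute(
        unparsed_identifier="id"
    )
)

# Wrap the Transpose node in a Relation, as the generated Relation.__init__
# above now accepts a `transpose` field.
rel = relations_pb2.Relation(
    transpose=relations_pb2.Transpose(
        input=relations_pb2.Relation(),  # placeholder child relation
        index_columns=[index_col],
    )
)
print(rel.WhichOneof("rel_type"))  # prints: transpose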
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index e5246e893f658..cacb479229bb7 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -19,6 +19,7 @@
check_dependencies(__name__)
+import json
import threading
import os
import warnings
@@ -200,6 +201,26 @@ def enableHiveSupport(self) -> "SparkSession.Builder":
)
def _apply_options(self, session: "SparkSession") -> None:
+ init_opts = {}
+ for i in range(int(os.environ.get("PYSPARK_REMOTE_INIT_CONF_LEN", "0"))):
+ init_opts = json.loads(os.environ[f"PYSPARK_REMOTE_INIT_CONF_{i}"])
+
+ with self._lock:
+ for k, v in init_opts.items():
+ # the options are applied after session creation,
+ # so the following options never take effect
+ if k not in [
+ "spark.remote",
+ "spark.master",
+ ] and k.startswith("spark.sql."):
+ # Only attempts to set Spark SQL configurations.
+ # If the configurations are static, it might throw an exception so
+ # simply ignore it for now.
+ try:
+ session.conf.set(k, v)
+ except Exception:
+ pass
+
with self._lock:
for k, v in self._options.items():
# the options are applied after session creation,
@@ -993,10 +1014,17 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
session = PySparkSession._instantiatedSession
if session is None or session._sc._jsc is None:
+ init_opts = {}
+ for i in range(int(os.environ.get("PYSPARK_REMOTE_INIT_CONF_LEN", "0"))):
+ init_opts = json.loads(os.environ[f"PYSPARK_REMOTE_INIT_CONF_{i}"])
+ init_opts.update(opts)
+ opts = init_opts
+
# Configurations to be overwritten
overwrite_conf = opts
overwrite_conf["spark.master"] = master
overwrite_conf["spark.local.connect"] = "1"
+ os.environ["SPARK_LOCAL_CONNECT"] = "1"
# Configurations to be set if unset.
default_conf = {"spark.plugins": "org.apache.spark.sql.connect.SparkConnectPlugin"}
@@ -1030,6 +1058,7 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
finally:
if origin_remote is not None:
os.environ["SPARK_REMOTE"] = origin_remote
+ del os.environ["SPARK_LOCAL_CONNECT"]
else:
raise PySparkRuntimeError(
errorClass="SESSION_OR_CONTEXT_EXISTS",
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 7d3900c7afbc5..2179a844b1e5e 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -43,6 +43,7 @@
from pyspark.sql.types import StructType, Row
from pyspark.sql.utils import dispatch_df_method
+
if TYPE_CHECKING:
from py4j.java_gateway import JavaObject
import pyarrow as pa
@@ -65,6 +66,7 @@
ArrowMapIterFunction,
DataFrameLike as PandasDataFrameLike,
)
+ from pyspark.sql.plot import PySparkPlotAccessor
from pyspark.sql.metrics import ExecutionInfo
@@ -1332,7 +1334,7 @@ def offset(self, num: int) -> "DataFrame":
.. versionadded:: 3.4.0
.. versionchanged:: 3.5.0
- Supports vanilla PySpark.
+ Supports classic PySpark.
Parameters
----------
@@ -6168,10 +6170,6 @@ def mapInPandas(
| 1| 21|
+---+---+
- Notes
- -----
- This API is experimental
-
See Also
--------
pyspark.sql.functions.pandas_udf
@@ -6245,10 +6243,6 @@ def mapInArrow(
| 1| 21|
+---+---+
- Notes
- -----
- This API is unstable, and for developers.
-
See Also
--------
pyspark.sql.functions.pandas_udf
@@ -6311,6 +6305,72 @@ def toPandas(self) -> "PandasDataFrameLike":
"""
...
+ @dispatch_df_method
+ def transpose(self, indexColumn: Optional["ColumnOrName"] = None) -> "DataFrame":
+ """
+ Transposes a DataFrame such that the values in the specified index column become the new
+ columns of the DataFrame. If no index column is provided, the first column is used as
+ the default.
+
+ Please note:
+ - All columns except the index column must share a least common data type. Unless they
+ are the same data type, all columns are cast to the nearest common data type.
+ - The name of the column into which the original column names are transposed defaults
+ to "key".
+ - null values in the index column are excluded from the column names for the
+ transposed table, which are ordered in ascending order.
+
+ .. versionadded:: 4.0.0
+
+ Parameters
+ ----------
+ indexColumn : str or :class:`Column`, optional
+ The single column that will be treated as the index for the transpose operation. This
+ column will be used to transform the DataFrame such that the values of the indexColumn
+ become the new columns in the transposed DataFrame. If not provided, the first column of
+ the DataFrame will be used as the default.
+
+ Returns
+ -------
+ :class:`DataFrame`
+ Transposed DataFrame.
+
+ Notes
+ -----
+ Supports Spark Connect.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame(
+ ... [("A", 1, 2), ("B", 3, 4)],
+ ... ["id", "val1", "val2"],
+ ... )
+ >>> df.show()
+ +---+----+----+
+ | id|val1|val2|
+ +---+----+----+
+ | A| 1| 2|
+ | B| 3| 4|
+ +---+----+----+
+
+ >>> df.transpose().show()
+ +----+---+---+
+ | key| A| B|
+ +----+---+---+
+ |val1| 1| 3|
+ |val2| 2| 4|
+ +----+---+---+
+
+ >>> df.transpose(df.id).show()
+ +----+---+---+
+ | key| A| B|
+ +----+---+---+
+ |val1| 1| 3|
+ |val2| 2| 4|
+ +----+---+---+
+ """
+ ...
+
@property
def executionInfo(self) -> Optional["ExecutionInfo"]:
"""
@@ -6336,6 +6396,32 @@ def executionInfo(self) -> Optional["ExecutionInfo"]:
"""
...
+ @property
+ def plot(self) -> "PySparkPlotAccessor":
+ """
+ Returns a :class:`PySparkPlotAccessor` for plotting functions.
+
+ .. versionadded:: 4.0.0
+
+ Returns
+ -------
+ :class:`PySparkPlotAccessor`
+
+ Notes
+ -----
+ This API is experimental.
+
+ Examples
+ --------
+ >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)]
+ >>> columns = ["category", "int_val", "float_val"]
+ >>> df = spark.createDataFrame(data, columns)
+ >>> type(df.plot)
+
+ >>> df.plot.line(x="category", y=["int_val", "float_val"]) # doctest: +SKIP
+ """
+ ...
+
class DataFrameNaFunctions:
"""Functionality for working with missing data in :class:`DataFrame`.
diff --git a/python/pyspark/sql/datasource.py b/python/pyspark/sql/datasource.py
index 72d42ae5e0cdf..a51c96a9d178f 100644
--- a/python/pyspark/sql/datasource.py
+++ b/python/pyspark/sql/datasource.py
@@ -23,9 +23,9 @@
from pyspark.errors import PySparkNotImplementedError
if TYPE_CHECKING:
+ from pyarrow import RecordBatch
from pyspark.sql.session import SparkSession
-
__all__ = [
"DataSource",
"DataSourceReader",
@@ -333,7 +333,7 @@ def partitions(self) -> Sequence[InputPartition]:
)
@abstractmethod
- def read(self, partition: InputPartition) -> Iterator[Tuple]:
+ def read(self, partition: InputPartition) -> Union[Iterator[Tuple], Iterator["RecordBatch"]]:
"""
Generates data for a given partition and returns an iterator of tuples or rows.
@@ -350,9 +350,11 @@ def read(self, partition: InputPartition) -> Iterator[Tuple]:
Returns
-------
- iterator of tuples or :class:`Row`\\s
+ iterator of tuples or PyArrow's `RecordBatch`
An iterator of tuples or rows. Each tuple or row will be converted to a row
in the final DataFrame.
+ It can also return an iterator of PyArrow's `RecordBatch` if the data source
+ supports it.
Examples
--------
@@ -448,7 +450,7 @@ def partitions(self, start: dict, end: dict) -> Sequence[InputPartition]:
)
@abstractmethod
- def read(self, partition: InputPartition) -> Iterator[Tuple]:
+ def read(self, partition: InputPartition) -> Union[Iterator[Tuple], Iterator["RecordBatch"]]:
"""
Generates data for a given partition and returns an iterator of tuples or rows.
@@ -470,9 +472,11 @@ def read(self, partition: InputPartition) -> Iterator[Tuple]:
Returns
-------
- iterator of tuples or :class:`Row`\\s
+ iterator of tuples or PyArrow's `RecordBatch`
An iterator of tuples or rows. Each tuple or row will be converted to a row
in the final DataFrame.
+ It can also return an iterator of PyArrow's `RecordBatch` if the data source
+ supports it.
"""
raise PySparkNotImplementedError(
errorClass="NOT_IMPLEMENTED",
diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py
index 344ba8d009ac4..5f8d1c21a24f1 100644
--- a/python/pyspark/sql/functions/builtin.py
+++ b/python/pyspark/sql/functions/builtin.py
@@ -6015,9 +6015,9 @@ def grouping_id(*cols: "ColumnOrName") -> Column:
@_try_remote_functions
def count_min_sketch(
col: "ColumnOrName",
- eps: "ColumnOrName",
- confidence: "ColumnOrName",
- seed: "ColumnOrName",
+ eps: Union[Column, float],
+ confidence: Union[Column, float],
+ seed: Optional[Union[Column, int]] = None,
) -> Column:
"""
Returns a count-min sketch of a column with the given eps, confidence and seed.
@@ -6031,13 +6031,24 @@ def count_min_sketch(
----------
col : :class:`~pyspark.sql.Column` or str
target column to compute on.
- eps : :class:`~pyspark.sql.Column` or str
+ eps : :class:`~pyspark.sql.Column` or float
relative error, must be positive
- confidence : :class:`~pyspark.sql.Column` or str
+
+ .. versionchanged:: 4.0.0
+ `eps` now accepts float value.
+
+ confidence : :class:`~pyspark.sql.Column` or float
confidence, must be positive and less than 1.0
- seed : :class:`~pyspark.sql.Column` or str
+
+ .. versionchanged:: 4.0.0
+ `confidence` now accepts float value.
+
+ seed : :class:`~pyspark.sql.Column` or int, optional
+ random seed
+
+ .. versionchanged:: 4.0.0
+ `seed` now accepts int value.
+
Returns
-------
:class:`~pyspark.sql.Column`
@@ -6045,12 +6056,48 @@ def count_min_sketch(
Examples
--------
- >>> df = spark.createDataFrame([[1], [2], [1]], ['data'])
- >>> df = df.agg(count_min_sketch(df.data, lit(0.5), lit(0.5), lit(1)).alias('sketch'))
- >>> df.select(hex(df.sketch).alias('r')).collect()
- [Row(r='0000000100000000000000030000000100000004000000005D8D6AB90000000000000000000000000000000200000000000000010000000000000000')]
- """
- return _invoke_function_over_columns("count_min_sketch", col, eps, confidence, seed)
+ Example 1: Using columns as arguments
+
+ >>> from pyspark.sql import functions as sf
+ >>> spark.range(100).select(
+ ... sf.hex(sf.count_min_sketch(sf.col("id"), sf.lit(3.0), sf.lit(0.1), sf.lit(1)))
+ ... ).show(truncate=False)
+ +------------------------------------------------------------------------+
+ |hex(count_min_sketch(id, 3.0, 0.1, 1)) |
+ +------------------------------------------------------------------------+
+ |0000000100000000000000640000000100000001000000005D8D6AB90000000000000064|
+ +------------------------------------------------------------------------+
+
+ Example 2: Using numbers as arguments
+
+ >>> from pyspark.sql import functions as sf
+ >>> spark.range(100).select(
+ ... sf.hex(sf.count_min_sketch("id", 1.0, 0.3, 2))
+ ... ).show(truncate=False)
+ +----------------------------------------------------------------------------------------+
+ |hex(count_min_sketch(id, 1.0, 0.3, 2)) |
+ +----------------------------------------------------------------------------------------+
+ |0000000100000000000000640000000100000002000000005D96391C00000000000000320000000000000032|
+ +----------------------------------------------------------------------------------------+
+
+ Example 3: Using a random seed
+
+ >>> from pyspark.sql import functions as sf
+ >>> spark.range(100).select(
+ ... sf.hex(sf.count_min_sketch("id", sf.lit(1.5), 0.6))
+ ... ).show(truncate=False) # doctest: +SKIP
+ +----------------------------------------------------------------------------------------------------------------------------------------+
+ |hex(count_min_sketch(id, 1.5, 0.6, 2120704260)) |
+ +----------------------------------------------------------------------------------------------------------------------------------------+
+ |0000000100000000000000640000000200000002000000005ADECCEE00000000153EBE090000000000000033000000000000003100000000000000320000000000000032|
+ +----------------------------------------------------------------------------------------------------------------------------------------+
+ """ # noqa: E501
+ _eps = lit(eps)
+ _conf = lit(confidence)
+ if seed is None:
+ return _invoke_function_over_columns("count_min_sketch", col, _eps, _conf)
+ else:
+ return _invoke_function_over_columns("count_min_sketch", col, _eps, _conf, lit(seed))
@_try_remote_functions
@@ -7305,36 +7352,36 @@ def lag(col: "ColumnOrName", offset: int = 1, default: Optional[Any] = None) ->
| b| 2|
+---+---+
>>> w = Window.partitionBy("c1").orderBy("c2")
- >>> df.withColumn("previos_value", lag("c2").over(w)).show()
- +---+---+-------------+
- | c1| c2|previos_value|
- +---+---+-------------+
- | a| 1| NULL|
- | a| 2| 1|
- | a| 3| 2|
- | b| 2| NULL|
- | b| 8| 2|
- +---+---+-------------+
- >>> df.withColumn("previos_value", lag("c2", 1, 0).over(w)).show()
- +---+---+-------------+
- | c1| c2|previos_value|
- +---+---+-------------+
- | a| 1| 0|
- | a| 2| 1|
- | a| 3| 2|
- | b| 2| 0|
- | b| 8| 2|
- +---+---+-------------+
- >>> df.withColumn("previos_value", lag("c2", 2, -1).over(w)).show()
- +---+---+-------------+
- | c1| c2|previos_value|
- +---+---+-------------+
- | a| 1| -1|
- | a| 2| -1|
- | a| 3| 1|
- | b| 2| -1|
- | b| 8| -1|
- +---+---+-------------+
+ >>> df.withColumn("previous_value", lag("c2").over(w)).show()
+ +---+---+--------------+
+ | c1| c2|previous_value|
+ +---+---+--------------+
+ | a| 1| NULL|
+ | a| 2| 1|
+ | a| 3| 2|
+ | b| 2| NULL|
+ | b| 8| 2|
+ +---+---+--------------+
+ >>> df.withColumn("previous_value", lag("c2", 1, 0).over(w)).show()
+ +---+---+--------------+
+ | c1| c2|previous_value|
+ +---+---+--------------+
+ | a| 1| 0|
+ | a| 2| 1|
+ | a| 3| 2|
+ | b| 2| 0|
+ | b| 8| 2|
+ +---+---+--------------+
+ >>> df.withColumn("previous_value", lag("c2", 2, -1).over(w)).show()
+ +---+---+--------------+
+ | c1| c2|previous_value|
+ +---+---+--------------+
+ | a| 1| -1|
+ | a| 2| -1|
+ | a| 3| 1|
+ | b| 2| -1|
+ | b| 8| -1|
+ +---+---+--------------+
"""
from pyspark.sql.classic.column import _to_java_column
@@ -11241,13 +11288,27 @@ def sentences(
) -> Column:
"""
Splits a string into arrays of sentences, where each sentence is an array of words.
- The 'language' and 'country' arguments are optional, and if omitted, the default locale is used.
+ The `language` and `country` arguments are optional.
+ When they are omitted:
+ 1. If both are omitted, `Locale.ROOT - locale(language='', country='')` is used.
+ `Locale.ROOT` is regarded as the base locale of all locales, and is used as the
+ language/country-neutral locale for locale-sensitive operations.
+ 2. If only the `country` is omitted, `locale(language, country='')` is used.
+ When they are null:
+ 1. If both are `null`, `Locale.US - locale(language='en', country='US')` is used.
+ 2. If the `language` is null and the `country` is not null,
+ `Locale.US - locale(language='en', country='US')` is used.
+ 3. If the `language` is not null and the `country` is null, `locale(language)` is used.
+ 4. If neither is `null`, `locale(language, country)` is used.
.. versionadded:: 3.2.0
.. versionchanged:: 3.4.0
Supports Spark Connect.
+ .. versionchanged:: 4.0.0
+ Supports `sentences(string, language)`.
+
Parameters
----------
string : :class:`~pyspark.sql.Column` or str
@@ -11271,6 +11332,12 @@ def sentences(
+-----------------------------------+
|[[This, is, an, example, sentence]]|
+-----------------------------------+
+ >>> df.select(sentences(df.string, lit("en"))).show(truncate=False)
+ +-----------------------------------+
+ |sentences(string, en, ) |
+ +-----------------------------------+
+ |[[This, is, an, example, sentence]]|
+ +-----------------------------------+
>>> df = spark.createDataFrame([["Hello world. How are you?"]], ["s"])
>>> df.select(sentences("s")).show(truncate=False)
+---------------------------------+
@@ -11289,7 +11356,9 @@ def sentences(
@_try_remote_functions
def substring(
- str: "ColumnOrName", pos: Union["ColumnOrName", int], len: Union["ColumnOrName", int]
+ str: "ColumnOrName",
+ pos: Union["ColumnOrName", int],
+ len: Union["ColumnOrName", int],
) -> Column:
"""
Substring starts at `pos` and is of length `len` when str is String type or
@@ -11328,16 +11397,59 @@ def substring(
Examples
--------
+ Example 1: Using literal integers as arguments
+
+ >>> import pyspark.sql.functions as sf
>>> df = spark.createDataFrame([('abcd',)], ['s',])
- >>> df.select(substring(df.s, 1, 2).alias('s')).collect()
- [Row(s='ab')]
+ >>> df.select('*', sf.substring(df.s, 1, 2)).show()
+ +----+------------------+
+ | s|substring(s, 1, 2)|
+ +----+------------------+
+ |abcd| ab|
+ +----+------------------+
+
+ Example 2: Using columns as arguments
+
+ >>> import pyspark.sql.functions as sf
>>> df = spark.createDataFrame([('Spark', 2, 3)], ['s', 'p', 'l'])
- >>> df.select(substring(df.s, 2, df.l).alias('s')).collect()
- [Row(s='par')]
- >>> df.select(substring(df.s, df.p, 3).alias('s')).collect()
- [Row(s='par')]
- >>> df.select(substring(df.s, df.p, df.l).alias('s')).collect()
- [Row(s='par')]
+ >>> df.select('*', sf.substring(df.s, 2, df.l)).show()
+ +-----+---+---+------------------+
+ | s| p| l|substring(s, 2, l)|
+ +-----+---+---+------------------+
+ |Spark| 2| 3| par|
+ +-----+---+---+------------------+
+
+ >>> df.select('*', sf.substring(df.s, df.p, 3)).show()
+ +-----+---+---+------------------+
+ | s| p| l|substring(s, p, 3)|
+ +-----+---+---+------------------+
+ |Spark| 2| 3| par|
+ +-----+---+---+------------------+
+
+ >>> df.select('*', sf.substring(df.s, df.p, df.l)).show()
+ +-----+---+---+------------------+
+ | s| p| l|substring(s, p, l)|
+ +-----+---+---+------------------+
+ |Spark| 2| 3| par|
+ +-----+---+---+------------------+
+
+ Example 3: Using column names as arguments
+
+ >>> import pyspark.sql.functions as sf
+ >>> df = spark.createDataFrame([('Spark', 2, 3)], ['s', 'p', 'l'])
+ >>> df.select('*', sf.substring(df.s, 2, 'l')).show()
+ +-----+---+---+------------------+
+ | s| p| l|substring(s, 2, l)|
+ +-----+---+---+------------------+
+ |Spark| 2| 3| par|
+ +-----+---+---+------------------+
+
+ >>> df.select('*', sf.substring('s', 'p', 'l')).show()
+ +-----+---+---+------------------+
+ | s| p| l|substring(s, p, l)|
+ +-----+---+---+------------------+
+ |Spark| 2| 3| par|
+ +-----+---+---+------------------+
"""
pos = _enum_to_value(pos)
pos = lit(pos) if isinstance(pos, int) else pos
@@ -16308,6 +16420,55 @@ def try_parse_json(
return _invoke_function("try_parse_json", _to_java_column(col))
+@_try_remote_functions
+def to_variant_object(
+ col: "ColumnOrName",
+) -> Column:
+ """
+ Converts a column containing nested inputs (array/map/struct) into variants, where maps and
+ structs are converted to variant objects, which are unordered unlike SQL structs. Input maps can
+ only have string keys.
+
+ .. versionadded:: 4.0.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or str
+ a column with a nested schema or column name
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ a new column of VariantType.
+
+ Examples
+ --------
+ Example 1: Converting an array containing a nested struct into a variant
+
+ >>> from pyspark.sql import functions as sf
+ >>> from pyspark.sql.types import ArrayType, StructType, StructField, StringType, MapType
+ >>> schema = StructType([
+ ... StructField("i", StringType(), True),
+ ... StructField("v", ArrayType(StructType([
+ ... StructField("a", MapType(StringType(), StringType()), True)
+ ... ]), True))
+ ... ])
+ >>> data = [("1", [{"a": {"b": 2}}])]
+ >>> df = spark.createDataFrame(data, schema)
+ >>> df.select(sf.to_variant_object(df.v))
+ DataFrame[to_variant_object(v): variant]
+ >>> df.select(sf.to_variant_object(df.v)).show(truncate=False)
+ +--------------------+
+ |to_variant_object(v)|
+ +--------------------+
+ |[{"a":{"b":"2"}}] |
+ +--------------------+
+ """
+ from pyspark.sql.classic.column import _to_java_column
+
+ return _invoke_function("to_variant_object", _to_java_column(col))
+
+
@_try_remote_functions
def parse_json(
col: "ColumnOrName",
@@ -16467,7 +16628,7 @@ def schema_of_variant(v: "ColumnOrName") -> Column:
--------
>>> df = spark.createDataFrame([ {'json': '''{ "a" : 1 }'''} ])
>>> df.select(schema_of_variant(parse_json(df.json)).alias("r")).collect()
- [Row(r='STRUCT')]
+ [Row(r='OBJECT')]
"""
from pyspark.sql.classic.column import _to_java_column
@@ -16495,7 +16656,7 @@ def schema_of_variant_agg(v: "ColumnOrName") -> Column:
--------
>>> df = spark.createDataFrame([ {'json': '''{ "a" : 1 }'''} ])
>>> df.select(schema_of_variant_agg(parse_json(df.json)).alias("r")).collect()
- [Row(r='STRUCT')]
+ [Row(r='OBJECT')]
"""
from pyspark.sql.classic.column import _to_java_column
diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py
index 06834553ea96a..3173534c03c91 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -224,8 +224,6 @@ def applyInPandas(
into memory, so the user should be aware of the potential OOM risk if data is skewed
and certain groups are too large to fit in memory.
- This API is experimental.
-
See Also
--------
pyspark.sql.functions.pandas_udf
@@ -329,8 +327,6 @@ def applyInPandasWithState(
Notes
-----
This function requires a full shuffle.
-
- This API is experimental.
"""
from pyspark.sql import GroupedData
@@ -484,8 +480,6 @@ def transformWithStateInPandas(
Notes
-----
This function requires a full shuffle.
-
- This API is experimental.
"""
from pyspark.sql import GroupedData
@@ -683,10 +677,6 @@ class PandasCogroupedOps:
.. versionchanged:: 3.4.0
Support Spark Connect.
-
- Notes
- -----
- This API is experimental.
"""
def __init__(self, gd1: "GroupedData", gd2: "GroupedData"):
@@ -778,8 +768,6 @@ def applyInPandas(
into memory, so the user should be aware of the potential OOM risk if data is skewed
and certain groups are too large to fit in memory.
- This API is experimental.
-
See Also
--------
pyspark.sql.functions.pandas_udf
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 6203d4d19d866..076226865f3a7 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -510,8 +510,8 @@ def _create_batch(self, series):
# If it returns a pd.Series, it should throw an error.
if not isinstance(s, pd.DataFrame):
raise PySparkValueError(
- "A field of type StructType expects a pandas.DataFrame, "
- "but got: %s" % str(type(s))
+ "Invalid return type. Please make sure that the UDF returns a "
+ "pandas.DataFrame when the specified return type is StructType."
)
arrs.append(self._create_struct_array(s, t))
else:
diff --git a/python/pyspark/sql/plot/__init__.py b/python/pyspark/sql/plot/__init__.py
new file mode 100644
index 0000000000000..6da07061b2a09
--- /dev/null
+++ b/python/pyspark/sql/plot/__init__.py
@@ -0,0 +1,21 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+This package includes the plotting APIs for PySpark DataFrame.
+"""
+from pyspark.sql.plot.core import * # noqa: F403, F401
diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py
new file mode 100644
index 0000000000000..392ef73b38845
--- /dev/null
+++ b/python/pyspark/sql/plot/core.py
@@ -0,0 +1,135 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Any, TYPE_CHECKING, Optional, Union
+from types import ModuleType
+from pyspark.errors import PySparkRuntimeError, PySparkValueError
+from pyspark.sql.utils import require_minimum_plotly_version
+
+
+if TYPE_CHECKING:
+ from pyspark.sql import DataFrame
+ import pandas as pd
+ from plotly.graph_objs import Figure
+
+
+class PySparkTopNPlotBase:
+ def get_top_n(self, sdf: "DataFrame") -> "pd.DataFrame":
+ from pyspark.sql import SparkSession
+
+ session = SparkSession.getActiveSession()
+ if session is None:
+ raise PySparkRuntimeError(errorClass="NO_ACTIVE_SESSION", messageParameters=dict())
+
+ max_rows = int(
+ session.conf.get("spark.sql.pyspark.plotting.max_rows") # type: ignore[arg-type]
+ )
+ pdf = sdf.limit(max_rows + 1).toPandas()
+
+ self.partial = False
+ if len(pdf) > max_rows:
+ self.partial = True
+ pdf = pdf.iloc[:max_rows]
+
+ return pdf
+
+
+class PySparkSampledPlotBase:
+ def get_sampled(self, sdf: "DataFrame") -> "pd.DataFrame":
+ from pyspark.sql import SparkSession
+
+ session = SparkSession.getActiveSession()
+ if session is None:
+ raise PySparkRuntimeError(errorClass="NO_ACTIVE_SESSION", messageParameters=dict())
+
+ sample_ratio = session.conf.get("spark.sql.pyspark.plotting.sample_ratio")
+ max_rows = int(
+ session.conf.get("spark.sql.pyspark.plotting.max_rows") # type: ignore[arg-type]
+ )
+
+ if sample_ratio is None:
+ fraction = 1 / (sdf.count() / max_rows)
+ fraction = min(1.0, fraction)
+ else:
+ fraction = float(sample_ratio)
+
+ sampled_sdf = sdf.sample(fraction=fraction)
+ pdf = sampled_sdf.toPandas()
+
+ return pdf
+
+
+class PySparkPlotAccessor:
+ plot_data_map = {
+ "line": PySparkSampledPlotBase().get_sampled,
+ }
+ _backends = {} # type: ignore[var-annotated]
+
+ def __init__(self, data: "DataFrame"):
+ self.data = data
+
+ def __call__(
+ self, kind: str = "line", backend: Optional[str] = None, **kwargs: Any
+ ) -> "Figure":
+ plot_backend = PySparkPlotAccessor._get_plot_backend(backend)
+
+ return plot_backend.plot_pyspark(self.data, kind=kind, **kwargs)
+
+ @staticmethod
+ def _get_plot_backend(backend: Optional[str] = None) -> ModuleType:
+ backend = backend or "plotly"
+
+ if backend in PySparkPlotAccessor._backends:
+ return PySparkPlotAccessor._backends[backend]
+
+ if backend == "plotly":
+ require_minimum_plotly_version()
+ else:
+ raise PySparkValueError(
+ errorClass="UNSUPPORTED_PLOT_BACKEND",
+ messageParameters={"backend": backend, "supported_backends": ", ".join(["plotly"])},
+ )
+ from pyspark.sql.plot import plotly as module
+
+ return module
+
+ def line(self, x: str, y: Union[str, list[str]], **kwargs: Any) -> "Figure":
+ """
+ Plot DataFrame as lines.
+
+ Parameters
+ ----------
+ x : str
+ Name of column to use for the horizontal axis.
+ y : str or list of str
+ Name(s) of the column(s) to use for the vertical axis. Multiple columns can be plotted.
+ **kwargs : optional
+ Additional keyword arguments.
+
+ Returns
+ -------
+ :class:`plotly.graph_objs.Figure`
+
+ Examples
+ --------
+ >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)]
+ >>> columns = ["category", "int_val", "float_val"]
+ >>> df = spark.createDataFrame(data, columns)
+ >>> df.plot.line(x="category", y="int_val") # doctest: +SKIP
+ >>> df.plot.line(x="category", y=["int_val", "float_val"]) # doctest: +SKIP
+ """
+ return self(kind="line", x=x, y=y, **kwargs)
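# Editor's note: a minimal usage sketch (not part of the patch above). The new accessor
# samples or truncates the DataFrame on the driver before handing it to plotly, governed by
# the two session confs read in PySparkSampledPlotBase/PySparkTopNPlotBase. It assumes the
# `plot` accessor is registered on DataFrame (as the tests later in this patch exercise)
# and that plotly is installed; the conf values and column names are made up.
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.pyspark.plotting.max_rows", "1000")
spark.conf.set("spark.sql.pyspark.plotting.sample_ratio", "0.5")

df = spark.range(10000).selectExpr("id AS x", "id * 2 AS y")
fig = df.plot.line(x="x", y="y")  # returns a plotly.graph_objs.Figure
fig.show()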
diff --git a/python/pyspark/sql/plot/plotly.py b/python/pyspark/sql/plot/plotly.py
new file mode 100644
index 0000000000000..5efc19476057f
--- /dev/null
+++ b/python/pyspark/sql/plot/plotly.py
@@ -0,0 +1,30 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import TYPE_CHECKING, Any
+
+from pyspark.sql.plot import PySparkPlotAccessor
+
+if TYPE_CHECKING:
+ from pyspark.sql import DataFrame
+ from plotly.graph_objs import Figure
+
+
+def plot_pyspark(data: "DataFrame", kind: str, **kwargs: Any) -> "Figure":
+ import plotly
+
+ return plotly.plot(PySparkPlotAccessor.plot_data_map[kind](data), kind, **kwargs)
diff --git a/python/pyspark/sql/streaming/StateMessage_pb2.py b/python/pyspark/sql/streaming/StateMessage_pb2.py
index 0f096e16d47ad..a22f004fd3048 100644
--- a/python/pyspark/sql/streaming/StateMessage_pb2.py
+++ b/python/pyspark/sql/streaming/StateMessage_pb2.py
@@ -16,14 +16,12 @@
#
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
-# NO CHECKED-IN PROTOBUF GENCODE
# source: StateMessage.proto
-# Protobuf Python Version: 5.27.1
"""Generated protocol buffer code."""
+from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
-from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
@@ -31,16 +29,17 @@
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x12StateMessage.proto\x12.org.apache.spark.sql.execution.streaming.state"\xe9\x02\n\x0cStateRequest\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\x66\n\x15statefulProcessorCall\x18\x02 \x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.StatefulProcessorCallH\x00\x12\x64\n\x14stateVariableRequest\x18\x03 \x01(\x0b\x32\x44.org.apache.spark.sql.execution.streaming.state.StateVariableRequestH\x00\x12p\n\x1aimplicitGroupingKeyRequest\x18\x04 \x01(\x0b\x32J.org.apache.spark.sql.execution.streaming.state.ImplicitGroupingKeyRequestH\x00\x42\x08\n\x06method"H\n\rStateResponse\x12\x12\n\nstatusCode\x18\x01 \x01(\x05\x12\x14\n\x0c\x65rrorMessage\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\x0c"\x89\x03\n\x15StatefulProcessorCall\x12X\n\x0esetHandleState\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetHandleStateH\x00\x12Y\n\rgetValueState\x18\x02 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x12X\n\x0cgetListState\x18\x03 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x12W\n\x0bgetMapState\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x42\x08\n\x06method"z\n\x14StateVariableRequest\x12X\n\x0evalueStateCall\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.ValueStateCallH\x00\x42\x08\n\x06method"\xe0\x01\n\x1aImplicitGroupingKeyRequest\x12X\n\x0esetImplicitKey\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetImplicitKeyH\x00\x12^\n\x11removeImplicitKey\x18\x02 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.RemoveImplicitKeyH\x00\x42\x08\n\x06method"5\n\x10StateCallCommand\x12\x11\n\tstateName\x18\x01 \x01(\t\x12\x0e\n\x06schema\x18\x02 \x01(\t"\xe1\x02\n\x0eValueStateCall\x12\x11\n\tstateName\x18\x01 \x01(\t\x12H\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00\x12\x42\n\x03get\x18\x03 \x01(\x0b\x32\x33.org.apache.spark.sql.execution.streaming.state.GetH\x00\x12\\\n\x10valueStateUpdate\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.ValueStateUpdateH\x00\x12\x46\n\x05\x63lear\x18\x05 \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00\x42\x08\n\x06method"\x1d\n\x0eSetImplicitKey\x12\x0b\n\x03key\x18\x01 \x01(\x0c"\x13\n\x11RemoveImplicitKey"\x08\n\x06\x45xists"\x05\n\x03Get"!\n\x10ValueStateUpdate\x12\r\n\x05value\x18\x01 \x01(\x0c"\x07\n\x05\x43lear"\\\n\x0eSetHandleState\x12J\n\x05state\x18\x01 \x01(\x0e\x32;.org.apache.spark.sql.execution.streaming.state.HandleState*K\n\x0bHandleState\x12\x0b\n\x07\x43REATED\x10\x00\x12\x0f\n\x0bINITIALIZED\x10\x01\x12\x12\n\x0e\x44\x41TA_PROCESSED\x10\x02\x12\n\n\x06\x43LOSED\x10\x03\x62\x06proto3' # noqa: E501
+ b'\n\x12StateMessage.proto\x12.org.apache.spark.sql.execution.streaming.state"\xe9\x02\n\x0cStateRequest\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\x66\n\x15statefulProcessorCall\x18\x02 \x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.StatefulProcessorCallH\x00\x12\x64\n\x14stateVariableRequest\x18\x03 \x01(\x0b\x32\x44.org.apache.spark.sql.execution.streaming.state.StateVariableRequestH\x00\x12p\n\x1aimplicitGroupingKeyRequest\x18\x04 \x01(\x0b\x32J.org.apache.spark.sql.execution.streaming.state.ImplicitGroupingKeyRequestH\x00\x42\x08\n\x06method"H\n\rStateResponse\x12\x12\n\nstatusCode\x18\x01 \x01(\x05\x12\x14\n\x0c\x65rrorMessage\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\x0c"\x89\x03\n\x15StatefulProcessorCall\x12X\n\x0esetHandleState\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetHandleStateH\x00\x12Y\n\rgetValueState\x18\x02 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x12X\n\x0cgetListState\x18\x03 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x12W\n\x0bgetMapState\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x42\x08\n\x06method"z\n\x14StateVariableRequest\x12X\n\x0evalueStateCall\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.ValueStateCallH\x00\x42\x08\n\x06method"\xe0\x01\n\x1aImplicitGroupingKeyRequest\x12X\n\x0esetImplicitKey\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetImplicitKeyH\x00\x12^\n\x11removeImplicitKey\x18\x02 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.RemoveImplicitKeyH\x00\x42\x08\n\x06method"}\n\x10StateCallCommand\x12\x11\n\tstateName\x18\x01 \x01(\t\x12\x0e\n\x06schema\x18\x02 \x01(\t\x12\x46\n\x03ttl\x18\x03 \x01(\x0b\x32\x39.org.apache.spark.sql.execution.streaming.state.TTLConfig"\xe1\x02\n\x0eValueStateCall\x12\x11\n\tstateName\x18\x01 \x01(\t\x12H\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00\x12\x42\n\x03get\x18\x03 \x01(\x0b\x32\x33.org.apache.spark.sql.execution.streaming.state.GetH\x00\x12\\\n\x10valueStateUpdate\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.ValueStateUpdateH\x00\x12\x46\n\x05\x63lear\x18\x05 \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00\x42\x08\n\x06method"\x1d\n\x0eSetImplicitKey\x12\x0b\n\x03key\x18\x01 \x01(\x0c"\x13\n\x11RemoveImplicitKey"\x08\n\x06\x45xists"\x05\n\x03Get"!\n\x10ValueStateUpdate\x12\r\n\x05value\x18\x01 \x01(\x0c"\x07\n\x05\x43lear"\\\n\x0eSetHandleState\x12J\n\x05state\x18\x01 \x01(\x0e\x32;.org.apache.spark.sql.execution.streaming.state.HandleState"\x1f\n\tTTLConfig\x12\x12\n\ndurationMs\x18\x01 \x01(\x05*K\n\x0bHandleState\x12\x0b\n\x07\x43REATED\x10\x00\x12\x0f\n\x0bINITIALIZED\x10\x01\x12\x12\n\x0e\x44\x41TA_PROCESSED\x10\x02\x12\n\n\x06\x43LOSED\x10\x03\x62\x06proto3' # noqa: E501
)
_globals = globals()
+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "StateMessage_pb2", _globals)
if not _descriptor._USE_C_DESCRIPTORS:
- DESCRIPTOR._loaded_options = None
- _globals["_HANDLESTATE"]._serialized_start = 1873
- _globals["_HANDLESTATE"]._serialized_end = 1948
+ DESCRIPTOR._options = None
+ _globals["_HANDLESTATE"]._serialized_start = 1978
+ _globals["_HANDLESTATE"]._serialized_end = 2053
_globals["_STATEREQUEST"]._serialized_start = 71
_globals["_STATEREQUEST"]._serialized_end = 432
_globals["_STATERESPONSE"]._serialized_start = 434
@@ -52,21 +51,23 @@
_globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_start = 1029
_globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_end = 1253
_globals["_STATECALLCOMMAND"]._serialized_start = 1255
- _globals["_STATECALLCOMMAND"]._serialized_end = 1308
- _globals["_VALUESTATECALL"]._serialized_start = 1311
- _globals["_VALUESTATECALL"]._serialized_end = 1664
- _globals["_SETIMPLICITKEY"]._serialized_start = 1666
- _globals["_SETIMPLICITKEY"]._serialized_end = 1695
- _globals["_REMOVEIMPLICITKEY"]._serialized_start = 1697
- _globals["_REMOVEIMPLICITKEY"]._serialized_end = 1716
- _globals["_EXISTS"]._serialized_start = 1718
- _globals["_EXISTS"]._serialized_end = 1726
- _globals["_GET"]._serialized_start = 1728
- _globals["_GET"]._serialized_end = 1733
- _globals["_VALUESTATEUPDATE"]._serialized_start = 1735
- _globals["_VALUESTATEUPDATE"]._serialized_end = 1768
- _globals["_CLEAR"]._serialized_start = 1770
- _globals["_CLEAR"]._serialized_end = 1777
- _globals["_SETHANDLESTATE"]._serialized_start = 1779
- _globals["_SETHANDLESTATE"]._serialized_end = 1871
+ _globals["_STATECALLCOMMAND"]._serialized_end = 1380
+ _globals["_VALUESTATECALL"]._serialized_start = 1383
+ _globals["_VALUESTATECALL"]._serialized_end = 1736
+ _globals["_SETIMPLICITKEY"]._serialized_start = 1738
+ _globals["_SETIMPLICITKEY"]._serialized_end = 1767
+ _globals["_REMOVEIMPLICITKEY"]._serialized_start = 1769
+ _globals["_REMOVEIMPLICITKEY"]._serialized_end = 1788
+ _globals["_EXISTS"]._serialized_start = 1790
+ _globals["_EXISTS"]._serialized_end = 1798
+ _globals["_GET"]._serialized_start = 1800
+ _globals["_GET"]._serialized_end = 1805
+ _globals["_VALUESTATEUPDATE"]._serialized_start = 1807
+ _globals["_VALUESTATEUPDATE"]._serialized_end = 1840
+ _globals["_CLEAR"]._serialized_start = 1842
+ _globals["_CLEAR"]._serialized_end = 1849
+ _globals["_SETHANDLESTATE"]._serialized_start = 1851
+ _globals["_SETHANDLESTATE"]._serialized_end = 1943
+ _globals["_TTLCONFIG"]._serialized_start = 1945
+ _globals["_TTLCONFIG"]._serialized_end = 1976
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/streaming/StateMessage_pb2.pyi b/python/pyspark/sql/streaming/StateMessage_pb2.pyi
index 0e6f1fb065881..1ab48a27c8f87 100644
--- a/python/pyspark/sql/streaming/StateMessage_pb2.pyi
+++ b/python/pyspark/sql/streaming/StateMessage_pb2.pyi
@@ -13,163 +13,167 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+#
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
-from typing import (
- ClassVar as _ClassVar,
- Mapping as _Mapping,
- Optional as _Optional,
- Union as _Union,
-)
+from typing import ClassVar, Mapping, Optional, Union
+CLOSED: HandleState
+CREATED: HandleState
+DATA_PROCESSED: HandleState
DESCRIPTOR: _descriptor.FileDescriptor
+INITIALIZED: HandleState
-class HandleState(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
+class Clear(_message.Message):
__slots__ = ()
- CREATED: _ClassVar[HandleState]
- INITIALIZED: _ClassVar[HandleState]
- DATA_PROCESSED: _ClassVar[HandleState]
- CLOSED: _ClassVar[HandleState]
+ def __init__(self) -> None: ...
-CREATED: HandleState
-INITIALIZED: HandleState
-DATA_PROCESSED: HandleState
-CLOSED: HandleState
+class Exists(_message.Message):
+ __slots__ = ()
+ def __init__(self) -> None: ...
+
+class Get(_message.Message):
+ __slots__ = ()
+ def __init__(self) -> None: ...
+
+class ImplicitGroupingKeyRequest(_message.Message):
+ __slots__ = ["removeImplicitKey", "setImplicitKey"]
+ REMOVEIMPLICITKEY_FIELD_NUMBER: ClassVar[int]
+ SETIMPLICITKEY_FIELD_NUMBER: ClassVar[int]
+ removeImplicitKey: RemoveImplicitKey
+ setImplicitKey: SetImplicitKey
+ def __init__(
+ self,
+ setImplicitKey: Optional[Union[SetImplicitKey, Mapping]] = ...,
+ removeImplicitKey: Optional[Union[RemoveImplicitKey, Mapping]] = ...,
+ ) -> None: ...
+
+class RemoveImplicitKey(_message.Message):
+ __slots__ = ()
+ def __init__(self) -> None: ...
+
+class SetHandleState(_message.Message):
+ __slots__ = ["state"]
+ STATE_FIELD_NUMBER: ClassVar[int]
+ state: HandleState
+ def __init__(self, state: Optional[Union[HandleState, str]] = ...) -> None: ...
+
+class SetImplicitKey(_message.Message):
+ __slots__ = ["key"]
+ KEY_FIELD_NUMBER: ClassVar[int]
+ key: bytes
+ def __init__(self, key: Optional[bytes] = ...) -> None: ...
+
+class StateCallCommand(_message.Message):
+ __slots__ = ["schema", "stateName", "ttl"]
+ SCHEMA_FIELD_NUMBER: ClassVar[int]
+ STATENAME_FIELD_NUMBER: ClassVar[int]
+ TTL_FIELD_NUMBER: ClassVar[int]
+ schema: str
+ stateName: str
+ ttl: TTLConfig
+ def __init__(
+ self,
+ stateName: Optional[str] = ...,
+ schema: Optional[str] = ...,
+ ttl: Optional[Union[TTLConfig, Mapping]] = ...,
+ ) -> None: ...
class StateRequest(_message.Message):
- __slots__ = (
- "version",
- "statefulProcessorCall",
- "stateVariableRequest",
+ __slots__ = [
"implicitGroupingKeyRequest",
- )
- VERSION_FIELD_NUMBER: _ClassVar[int]
- STATEFULPROCESSORCALL_FIELD_NUMBER: _ClassVar[int]
- STATEVARIABLEREQUEST_FIELD_NUMBER: _ClassVar[int]
- IMPLICITGROUPINGKEYREQUEST_FIELD_NUMBER: _ClassVar[int]
- version: int
- statefulProcessorCall: StatefulProcessorCall
- stateVariableRequest: StateVariableRequest
+ "stateVariableRequest",
+ "statefulProcessorCall",
+ "version",
+ ]
+ IMPLICITGROUPINGKEYREQUEST_FIELD_NUMBER: ClassVar[int]
+ STATEFULPROCESSORCALL_FIELD_NUMBER: ClassVar[int]
+ STATEVARIABLEREQUEST_FIELD_NUMBER: ClassVar[int]
+ VERSION_FIELD_NUMBER: ClassVar[int]
implicitGroupingKeyRequest: ImplicitGroupingKeyRequest
+ stateVariableRequest: StateVariableRequest
+ statefulProcessorCall: StatefulProcessorCall
+ version: int
def __init__(
self,
- version: _Optional[int] = ...,
- statefulProcessorCall: _Optional[_Union[StatefulProcessorCall, _Mapping]] = ...,
- stateVariableRequest: _Optional[_Union[StateVariableRequest, _Mapping]] = ...,
- implicitGroupingKeyRequest: _Optional[_Union[ImplicitGroupingKeyRequest, _Mapping]] = ...,
+ version: Optional[int] = ...,
+ statefulProcessorCall: Optional[Union[StatefulProcessorCall, Mapping]] = ...,
+ stateVariableRequest: Optional[Union[StateVariableRequest, Mapping]] = ...,
+ implicitGroupingKeyRequest: Optional[Union[ImplicitGroupingKeyRequest, Mapping]] = ...,
) -> None: ...
class StateResponse(_message.Message):
- __slots__ = ("statusCode", "errorMessage", "value")
- STATUSCODE_FIELD_NUMBER: _ClassVar[int]
- ERRORMESSAGE_FIELD_NUMBER: _ClassVar[int]
- VALUE_FIELD_NUMBER: _ClassVar[int]
- statusCode: int
+ __slots__ = ["errorMessage", "statusCode", "value"]
+ ERRORMESSAGE_FIELD_NUMBER: ClassVar[int]
+ STATUSCODE_FIELD_NUMBER: ClassVar[int]
+ VALUE_FIELD_NUMBER: ClassVar[int]
errorMessage: str
+ statusCode: int
value: bytes
- def __init__(
- self, statusCode: _Optional[int] = ..., errorMessage: _Optional[str] = ...
- ) -> None: ...
-
-class StatefulProcessorCall(_message.Message):
- __slots__ = ("setHandleState", "getValueState", "getListState", "getMapState")
- SETHANDLESTATE_FIELD_NUMBER: _ClassVar[int]
- GETVALUESTATE_FIELD_NUMBER: _ClassVar[int]
- GETLISTSTATE_FIELD_NUMBER: _ClassVar[int]
- GETMAPSTATE_FIELD_NUMBER: _ClassVar[int]
- setHandleState: SetHandleState
- getValueState: StateCallCommand
- getListState: StateCallCommand
- getMapState: StateCallCommand
def __init__(
self,
- setHandleState: _Optional[_Union[SetHandleState, _Mapping]] = ...,
- getValueState: _Optional[_Union[StateCallCommand, _Mapping]] = ...,
- getListState: _Optional[_Union[StateCallCommand, _Mapping]] = ...,
- getMapState: _Optional[_Union[StateCallCommand, _Mapping]] = ...,
+ statusCode: Optional[int] = ...,
+ errorMessage: Optional[str] = ...,
+ value: Optional[bytes] = ...,
) -> None: ...
class StateVariableRequest(_message.Message):
- __slots__ = ("valueStateCall",)
- VALUESTATECALL_FIELD_NUMBER: _ClassVar[int]
+ __slots__ = ["valueStateCall"]
+ VALUESTATECALL_FIELD_NUMBER: ClassVar[int]
valueStateCall: ValueStateCall
- def __init__(
- self, valueStateCall: _Optional[_Union[ValueStateCall, _Mapping]] = ...
- ) -> None: ...
+ def __init__(self, valueStateCall: Optional[Union[ValueStateCall, Mapping]] = ...) -> None: ...
-class ImplicitGroupingKeyRequest(_message.Message):
- __slots__ = ("setImplicitKey", "removeImplicitKey")
- SETIMPLICITKEY_FIELD_NUMBER: _ClassVar[int]
- REMOVEIMPLICITKEY_FIELD_NUMBER: _ClassVar[int]
- setImplicitKey: SetImplicitKey
- removeImplicitKey: RemoveImplicitKey
+class StatefulProcessorCall(_message.Message):
+ __slots__ = ["getListState", "getMapState", "getValueState", "setHandleState"]
+ GETLISTSTATE_FIELD_NUMBER: ClassVar[int]
+ GETMAPSTATE_FIELD_NUMBER: ClassVar[int]
+ GETVALUESTATE_FIELD_NUMBER: ClassVar[int]
+ SETHANDLESTATE_FIELD_NUMBER: ClassVar[int]
+ getListState: StateCallCommand
+ getMapState: StateCallCommand
+ getValueState: StateCallCommand
+ setHandleState: SetHandleState
def __init__(
self,
- setImplicitKey: _Optional[_Union[SetImplicitKey, _Mapping]] = ...,
- removeImplicitKey: _Optional[_Union[RemoveImplicitKey, _Mapping]] = ...,
+ setHandleState: Optional[Union[SetHandleState, Mapping]] = ...,
+ getValueState: Optional[Union[StateCallCommand, Mapping]] = ...,
+ getListState: Optional[Union[StateCallCommand, Mapping]] = ...,
+ getMapState: Optional[Union[StateCallCommand, Mapping]] = ...,
) -> None: ...
-class StateCallCommand(_message.Message):
- __slots__ = ("stateName", "schema")
- STATENAME_FIELD_NUMBER: _ClassVar[int]
- SCHEMA_FIELD_NUMBER: _ClassVar[int]
- stateName: str
- schema: str
- def __init__(self, stateName: _Optional[str] = ..., schema: _Optional[str] = ...) -> None: ...
+class TTLConfig(_message.Message):
+ __slots__ = ["durationMs"]
+ DURATIONMS_FIELD_NUMBER: ClassVar[int]
+ durationMs: int
+ def __init__(self, durationMs: Optional[int] = ...) -> None: ...
class ValueStateCall(_message.Message):
- __slots__ = ("stateName", "exists", "get", "valueStateUpdate", "clear")
- STATENAME_FIELD_NUMBER: _ClassVar[int]
- EXISTS_FIELD_NUMBER: _ClassVar[int]
- GET_FIELD_NUMBER: _ClassVar[int]
- VALUESTATEUPDATE_FIELD_NUMBER: _ClassVar[int]
- CLEAR_FIELD_NUMBER: _ClassVar[int]
- stateName: str
+ __slots__ = ["clear", "exists", "get", "stateName", "valueStateUpdate"]
+ CLEAR_FIELD_NUMBER: ClassVar[int]
+ EXISTS_FIELD_NUMBER: ClassVar[int]
+ GET_FIELD_NUMBER: ClassVar[int]
+ STATENAME_FIELD_NUMBER: ClassVar[int]
+ VALUESTATEUPDATE_FIELD_NUMBER: ClassVar[int]
+ clear: Clear
exists: Exists
get: Get
+ stateName: str
valueStateUpdate: ValueStateUpdate
- clear: Clear
def __init__(
self,
- stateName: _Optional[str] = ...,
- exists: _Optional[_Union[Exists, _Mapping]] = ...,
- get: _Optional[_Union[Get, _Mapping]] = ...,
- valueStateUpdate: _Optional[_Union[ValueStateUpdate, _Mapping]] = ...,
- clear: _Optional[_Union[Clear, _Mapping]] = ...,
+ stateName: Optional[str] = ...,
+ exists: Optional[Union[Exists, Mapping]] = ...,
+ get: Optional[Union[Get, Mapping]] = ...,
+ valueStateUpdate: Optional[Union[ValueStateUpdate, Mapping]] = ...,
+ clear: Optional[Union[Clear, Mapping]] = ...,
) -> None: ...
-class SetImplicitKey(_message.Message):
- __slots__ = ("key",)
- KEY_FIELD_NUMBER: _ClassVar[int]
- key: bytes
- def __init__(self, key: _Optional[bytes] = ...) -> None: ...
-
-class RemoveImplicitKey(_message.Message):
- __slots__ = ()
- def __init__(self) -> None: ...
-
-class Exists(_message.Message):
- __slots__ = ()
- def __init__(self) -> None: ...
-
-class Get(_message.Message):
- __slots__ = ()
- def __init__(self) -> None: ...
-
class ValueStateUpdate(_message.Message):
- __slots__ = ("value",)
- VALUE_FIELD_NUMBER: _ClassVar[int]
+ __slots__ = ["value"]
+ VALUE_FIELD_NUMBER: ClassVar[int]
value: bytes
- def __init__(self, value: _Optional[bytes] = ...) -> None: ...
+ def __init__(self, value: Optional[bytes] = ...) -> None: ...
-class Clear(_message.Message):
+class HandleState(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
- def __init__(self) -> None: ...
-
-class SetHandleState(_message.Message):
- __slots__ = ("state",)
- STATE_FIELD_NUMBER: _ClassVar[int]
- state: HandleState
- def __init__(self, state: _Optional[_Union[HandleState, str]] = ...) -> None: ...
diff --git a/python/pyspark/sql/streaming/stateful_processor.py b/python/pyspark/sql/streaming/stateful_processor.py
index a378eec2b6175..9045c81e287cd 100644
--- a/python/pyspark/sql/streaming/stateful_processor.py
+++ b/python/pyspark/sql/streaming/stateful_processor.py
@@ -88,7 +88,9 @@ class StatefulProcessorHandle:
def __init__(self, stateful_processor_api_client: StatefulProcessorApiClient) -> None:
self.stateful_processor_api_client = stateful_processor_api_client
- def getValueState(self, state_name: str, schema: Union[StructType, str]) -> ValueState:
+ def getValueState(
+ self, state_name: str, schema: Union[StructType, str], ttl_duration_ms: Optional[int] = None
+ ) -> ValueState:
"""
Function to create new or return existing single value state variable of given type.
The user must ensure to call this function only within the `init()` method of the
@@ -101,8 +103,13 @@ def getValueState(self, state_name: str, schema: Union[StructType, str]) -> Valu
schema : :class:`pyspark.sql.types.DataType` or str
The schema of the state variable. The value can be either a
:class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
+        ttl_duration_ms : int, optional
+            Time-to-live duration of the state in milliseconds. State values will not be returned
+            past the TTL duration and will eventually be removed from the state store. Any state
+            update resets the expiration time to the current processing time plus the TTL duration.
+            If no TTL is specified, the state never expires.
"""
- self.stateful_processor_api_client.get_value_state(state_name, schema)
+ self.stateful_processor_api_client.get_value_state(state_name, schema, ttl_duration_ms)
return ValueState(ValueStateClient(self.stateful_processor_api_client), state_name, schema)
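# Editor's note: a minimal sketch (not part of the patch) of how the new optional TTL
# argument is used from a StatefulProcessor. The processor, state name, and schema below
# are hypothetical; only the three-argument getValueState signature comes from this change.
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import IntegerType, StructField, StructType


class CountingProcessor(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle) -> None:
        state_schema = StructType([StructField("value", IntegerType(), True)])
        # State expires 30 seconds after its last update; omit the TTL argument
        # for state that never expires.
        self.count_state = handle.getValueState("count", state_schema, 30000)

    def handleInputRows(self, key, rows):
        for pdf in rows:
            yield pdf  # pass-through; real logic would read/update self.count_state

    def close(self) -> None:
        pass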
diff --git a/python/pyspark/sql/streaming/stateful_processor_api_client.py b/python/pyspark/sql/streaming/stateful_processor_api_client.py
index 080d7739992ec..9703aa17d3474 100644
--- a/python/pyspark/sql/streaming/stateful_processor_api_client.py
+++ b/python/pyspark/sql/streaming/stateful_processor_api_client.py
@@ -17,7 +17,7 @@
from enum import Enum
import os
import socket
-from typing import Any, Union, cast, Tuple
+from typing import Any, Union, Optional, cast, Tuple
from pyspark.serializers import write_int, read_int, UTF8Deserializer
from pyspark.sql.types import StructType, _parse_datatype_string, Row
@@ -101,7 +101,9 @@ def remove_implicit_key(self) -> None:
# TODO(SPARK-49233): Classify errors thrown by internal methods.
raise PySparkRuntimeError(f"Error removing implicit key: " f"{response_message[1]}")
- def get_value_state(self, state_name: str, schema: Union[StructType, str]) -> None:
+ def get_value_state(
+ self, state_name: str, schema: Union[StructType, str], ttl_duration_ms: Optional[int]
+ ) -> None:
import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
if isinstance(schema, str):
@@ -110,6 +112,8 @@ def get_value_state(self, state_name: str, schema: Union[StructType, str]) -> No
state_call_command = stateMessage.StateCallCommand()
state_call_command.stateName = state_name
state_call_command.schema = schema.json()
+ if ttl_duration_ms is not None:
+ state_call_command.ttl.durationMs = ttl_duration_ms
call = stateMessage.StatefulProcessorCall(getValueState=state_call_command)
message = stateMessage.StateRequest(statefulProcessorCall=call)
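# Editor's note: an illustrative sketch (not part of the patch) of the request built above.
# The TTL rides on the new TTLConfig field of StateCallCommand added to StateMessage.proto;
# the state name and schema string here are made up.
import pyspark.sql.streaming.StateMessage_pb2 as stateMessage

cmd = stateMessage.StateCallCommand(stateName="count", schema='{"type":"struct","fields":[]}')
cmd.ttl.durationMs = 30000
request = stateMessage.StateRequest(
    statefulProcessorCall=stateMessage.StatefulProcessorCall(getValueState=cmd)
)
assert request.statefulProcessorCall.getValueState.ttl.durationMs == 30000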
diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py
index 196c9eb5d81d8..5deb73a0ccf90 100644
--- a/python/pyspark/sql/tests/connect/client/test_client.py
+++ b/python/pyspark/sql/tests/connect/client/test_client.py
@@ -417,6 +417,18 @@ def checks():
for b in parameters:
not_found_fails(b)
+ def test_observed_session_id(self):
+ stub = self._stub_with([self.response, self.finished])
+ ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, [])
+ session_id = "test-session-id"
+
+ reattach = ite._create_reattach_execute_request()
+ self.assertEqual(reattach.client_observed_server_side_session_id, "")
+
+ self.request.client_observed_server_side_session_id = session_id
+ reattach = ite._create_reattach_execute_request()
+ self.assertEqual(reattach.client_observed_server_side_session_id, session_id)
+
if __name__ == "__main__":
from pyspark.sql.tests.connect.client.test_client import * # noqa: F401
diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
index c3d0d28017e60..f05f982d2d14e 100644
--- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
+++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
@@ -176,7 +176,7 @@ def test_slow_query(self):
def test_listener_throw(self):
"""
- Following Vanilla Spark's behavior, when the callback of user-defined listener throws,
+        Following classic Spark's behavior, when the callback of a user-defined listener throws,
other listeners should still proceed.
"""
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py
index 51ce1cd685210..e29873173cc3a 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -2572,7 +2572,7 @@ def test_function_parity(self):
cf_fn = {name for (name, value) in getmembers(CF, isfunction) if name[0] != "_"}
- # Functions in vanilla PySpark we do not expect to be available in Spark Connect
+ # Functions in classic PySpark we do not expect to be available in Spark Connect
sf_excluded_fn = set()
self.assertEqual(
@@ -2581,7 +2581,7 @@ def test_function_parity(self):
"Missing functions in Spark Connect not as expected",
)
- # Functions in Spark Connect we do not expect to be available in vanilla PySpark
+ # Functions in Spark Connect we do not expect to be available in classic PySpark
cf_excluded_fn = {
"check_dependencies", # internal helper function
}
@@ -2589,7 +2589,7 @@ def test_function_parity(self):
self.assertEqual(
cf_fn - sf_fn,
cf_excluded_fn,
- "Missing functions in vanilla PySpark not as expected",
+ "Missing functions in classic PySpark not as expected",
)
# SPARK-45216: Fix non-deterministic seeded Dataset APIs
diff --git a/python/pyspark/sql/tests/connect/test_parity_frame_plot.py b/python/pyspark/sql/tests/connect/test_parity_frame_plot.py
new file mode 100644
index 0000000000000..c69e438bf7eb0
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_frame_plot.py
@@ -0,0 +1,36 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.sql.tests.plot.test_frame_plot import DataFramePlotTestsMixin
+
+
+class FramePlotParityTests(DataFramePlotTestsMixin, ReusedConnectTestCase):
+ pass
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.sql.tests.connect.test_parity_frame_plot import * # noqa: F401
+
+ try:
+ import xmlrunner # type: ignore[import]
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/connect/test_parity_frame_plot_plotly.py b/python/pyspark/sql/tests/connect/test_parity_frame_plot_plotly.py
new file mode 100644
index 0000000000000..78508fe533379
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_frame_plot_plotly.py
@@ -0,0 +1,36 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.sql.tests.plot.test_frame_plot_plotly import DataFramePlotPlotlyTestsMixin
+
+
+class FramePlotPlotlyParityTests(DataFramePlotPlotlyTestsMixin, ReusedConnectTestCase):
+ pass
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.sql.tests.connect.test_parity_frame_plot_plotly import * # noqa: F401
+
+ try:
+ import xmlrunner # type: ignore[import]
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
index f05a601094a5d..8ad24704de3a4 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
@@ -16,6 +16,7 @@
#
import os
+import time
import tempfile
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from typing import Iterator
@@ -32,6 +33,7 @@
Row,
IntegerType,
)
+from pyspark.testing import assertDataFrameEqual
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
have_pandas,
@@ -59,20 +61,18 @@ def conf(cls):
)
return cfg
+ def _prepare_input_data(self, input_path, col1, col2):
+ with open(input_path, "w") as fw:
+ for e1, e2 in zip(col1, col2):
+ fw.write(f"{e1}, {e2}\n")
+
def _prepare_test_resource1(self, input_path):
- with open(input_path + "/text-test1.txt", "w") as fw:
- fw.write("0, 123\n")
- fw.write("0, 46\n")
- fw.write("1, 146\n")
- fw.write("1, 346\n")
+ self._prepare_input_data(input_path + "/text-test1.txt", [0, 0, 1, 1], [123, 46, 146, 346])
def _prepare_test_resource2(self, input_path):
- with open(input_path + "/text-test2.txt", "w") as fw:
- fw.write("0, 123\n")
- fw.write("0, 223\n")
- fw.write("0, 323\n")
- fw.write("1, 246\n")
- fw.write("1, 6\n")
+ self._prepare_input_data(
+ input_path + "/text-test2.txt", [0, 0, 0, 1, 1], [123, 223, 323, 246, 6]
+ )
def _build_test_df(self, input_path):
df = self.spark.readStream.format("text").option("maxFilesPerTrigger", 1).load(input_path)
@@ -84,7 +84,7 @@ def _build_test_df(self, input_path):
return df_final
def _test_transform_with_state_in_pandas_basic(
- self, stateful_processor, check_results, single_batch=False
+ self, stateful_processor, check_results, single_batch=False, timeMode="None"
):
input_path = tempfile.mkdtemp()
self._prepare_test_resource1(input_path)
@@ -110,7 +110,7 @@ def _test_transform_with_state_in_pandas_basic(
statefulProcessor=stateful_processor,
outputStructType=output_schema,
outputMode="Update",
- timeMode="None",
+ timeMode=timeMode,
)
.writeStream.queryName("this_query")
.foreachBatch(check_results)
@@ -211,6 +211,99 @@ def test_transform_with_state_in_pandas_query_restarts(self):
Row(id="1", countAsString="2"),
}
+    # Test that value state with TTL behaves the same as value state without TTL
+    # when the state does not expire.
+ def test_value_state_ttl_basic(self):
+ def check_results(batch_df, batch_id):
+ if batch_id == 0:
+ assert set(batch_df.sort("id").collect()) == {
+ Row(id="0", countAsString="2"),
+ Row(id="1", countAsString="2"),
+ }
+ else:
+ assert set(batch_df.sort("id").collect()) == {
+ Row(id="0", countAsString="3"),
+ Row(id="1", countAsString="2"),
+ }
+
+ self._test_transform_with_state_in_pandas_basic(
+ SimpleTTLStatefulProcessor(), check_results, False, "processingTime"
+ )
+
+ def test_value_state_ttl_expiration(self):
+ def check_results(batch_df, batch_id):
+ if batch_id == 0:
+ assertDataFrameEqual(
+ batch_df,
+ [
+ Row(id="ttl-count-0", count=1),
+ Row(id="count-0", count=1),
+ Row(id="ttl-count-1", count=1),
+ Row(id="count-1", count=1),
+ ],
+ )
+ elif batch_id == 1:
+ assertDataFrameEqual(
+ batch_df,
+ [
+ Row(id="ttl-count-0", count=2),
+ Row(id="count-0", count=2),
+ Row(id="ttl-count-1", count=2),
+ Row(id="count-1", count=2),
+ ],
+ )
+ elif batch_id == 2:
+                # ttl-count-0 expires and restarts from count 0.
+                # ttl-count-1 is reset in batch 1 and keeps its state.
+                # Non-TTL state never expires.
+ assertDataFrameEqual(
+ batch_df,
+ [
+ Row(id="ttl-count-0", count=1),
+ Row(id="count-0", count=3),
+ Row(id="ttl-count-1", count=3),
+ Row(id="count-1", count=3),
+ ],
+ )
+ if batch_id == 0 or batch_id == 1:
+ time.sleep(6)
+
+ input_dir = tempfile.TemporaryDirectory()
+ input_path = input_dir.name
+ try:
+ df = self._build_test_df(input_path)
+ self._prepare_input_data(input_path + "/batch1.txt", [1, 0], [0, 0])
+ self._prepare_input_data(input_path + "/batch2.txt", [1, 0], [0, 0])
+ self._prepare_input_data(input_path + "/batch3.txt", [1, 0], [0, 0])
+ for q in self.spark.streams.active:
+ q.stop()
+ output_schema = StructType(
+ [
+ StructField("id", StringType(), True),
+ StructField("count", IntegerType(), True),
+ ]
+ )
+
+ q = (
+ df.groupBy("id")
+ .transformWithStateInPandas(
+ statefulProcessor=TTLStatefulProcessor(),
+ outputStructType=output_schema,
+ outputMode="Update",
+ timeMode="processingTime",
+ )
+ .writeStream.foreachBatch(check_results)
+ .outputMode("update")
+ .start()
+ )
+ self.assertTrue(q.isActive)
+ q.processAllAvailable()
+ q.stop()
+ q.awaitTermination()
+ self.assertTrue(q.exception() is None)
+ finally:
+ input_dir.cleanup()
+
class SimpleStatefulProcessor(StatefulProcessor):
dict = {0: {"0": 1, "1": 2}, 1: {"0": 4, "1": 3}}
@@ -246,6 +339,43 @@ def close(self) -> None:
pass
+# A stateful processor that inherits all behavior of SimpleStatefulProcessor except that it
+# uses TTL state with a large timeout.
+class SimpleTTLStatefulProcessor(SimpleStatefulProcessor):
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ state_schema = StructType([StructField("value", IntegerType(), True)])
+ self.num_violations_state = handle.getValueState("numViolations", state_schema, 30000)
+
+
+class TTLStatefulProcessor(StatefulProcessor):
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ state_schema = StructType([StructField("value", IntegerType(), True)])
+ self.ttl_count_state = handle.getValueState("ttl-state", state_schema, 10000)
+ self.count_state = handle.getValueState("state", state_schema)
+
+ def handleInputRows(self, key, rows) -> Iterator[pd.DataFrame]:
+ count = 0
+ ttl_count = 0
+ id = key[0]
+ if self.count_state.exists():
+ count = self.count_state.get()[0]
+ if self.ttl_count_state.exists():
+ ttl_count = self.ttl_count_state.get()[0]
+ for pdf in rows:
+ pdf_count = pdf.count().get("temperature")
+ count += pdf_count
+ ttl_count += pdf_count
+
+ self.count_state.update((count,))
+        # Skip updating state for the 2nd batch so that the TTL state expires.
+ if not (ttl_count == 2 and id == "0"):
+ self.ttl_count_state.update((ttl_count,))
+ yield pd.DataFrame({"id": [f"ttl-count-{id}", f"count-{id}"], "count": [ttl_count, count]})
+
+ def close(self) -> None:
+ pass
+
+
class InvalidSimpleStatefulProcessor(StatefulProcessor):
def init(self, handle: StatefulProcessorHandle) -> None:
state_schema = StructType([StructField("value", IntegerType(), True)])
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf.py b/python/pyspark/sql/tests/pandas/test_pandas_udf.py
index 6720dfc37d0cc..228fc30b497cc 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf.py
@@ -339,6 +339,19 @@ def noop(s: pd.Series) -> pd.Series:
self.assertEqual(df.schema[0].dataType.simpleString(), "interval day to second")
self.assertEqual(df.first()[0], datetime.timedelta(microseconds=123))
+ def test_pandas_udf_return_type_error(self):
+ import pandas as pd
+
+ @pandas_udf("s string")
+ def upper(s: pd.Series) -> pd.Series:
+ return s.str.upper()
+
+ df = self.spark.createDataFrame([("a",)], schema="s string")
+
+ self.assertRaisesRegex(
+ PythonException, "Invalid return type", df.select(upper("s")).collect
+ )
+
class PandasUDFTests(PandasUDFTestsMixin, ReusedSQLTestCase):
pass
diff --git a/python/pyspark/sql/tests/plot/__init__.py b/python/pyspark/sql/tests/plot/__init__.py
new file mode 100644
index 0000000000000..cce3acad34a49
--- /dev/null
+++ b/python/pyspark/sql/tests/plot/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/python/pyspark/sql/tests/plot/test_frame_plot.py b/python/pyspark/sql/tests/plot/test_frame_plot.py
new file mode 100644
index 0000000000000..f753b5ab3db72
--- /dev/null
+++ b/python/pyspark/sql/tests/plot/test_frame_plot.py
@@ -0,0 +1,80 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from pyspark.errors import PySparkValueError
+from pyspark.sql import Row
+from pyspark.sql.plot import PySparkSampledPlotBase, PySparkTopNPlotBase
+from pyspark.testing.sqlutils import ReusedSQLTestCase, have_plotly, plotly_requirement_message
+
+
+@unittest.skipIf(not have_plotly, plotly_requirement_message)
+class DataFramePlotTestsMixin:
+ def test_backend(self):
+ accessor = self.spark.range(2).plot
+ backend = accessor._get_plot_backend()
+ self.assertEqual(backend.__name__, "pyspark.sql.plot.plotly")
+
+ with self.assertRaises(PySparkValueError) as pe:
+ accessor._get_plot_backend("matplotlib")
+
+ self.check_error(
+ exception=pe.exception,
+ errorClass="UNSUPPORTED_PLOT_BACKEND",
+ messageParameters={"backend": "matplotlib", "supported_backends": "plotly"},
+ )
+
+ def test_topn_max_rows(self):
+ try:
+ self.spark.conf.set("spark.sql.pyspark.plotting.max_rows", "1000")
+ sdf = self.spark.range(2500)
+ pdf = PySparkTopNPlotBase().get_top_n(sdf)
+ self.assertEqual(len(pdf), 1000)
+ finally:
+ self.spark.conf.unset("spark.sql.pyspark.plotting.max_rows")
+
+ def test_sampled_plot_with_ratio(self):
+ try:
+ self.spark.conf.set("spark.sql.pyspark.plotting.sample_ratio", "0.5")
+ data = [Row(a=i, b=i + 1, c=i + 2, d=i + 3) for i in range(2500)]
+ sdf = self.spark.createDataFrame(data)
+ pdf = PySparkSampledPlotBase().get_sampled(sdf)
+ self.assertEqual(round(len(pdf) / 2500, 1), 0.5)
+ finally:
+ self.spark.conf.unset("spark.sql.pyspark.plotting.sample_ratio")
+
+ def test_sampled_plot_with_max_rows(self):
+ data = [Row(a=i, b=i + 1, c=i + 2, d=i + 3) for i in range(2000)]
+ sdf = self.spark.createDataFrame(data)
+ pdf = PySparkSampledPlotBase().get_sampled(sdf)
+ self.assertEqual(round(len(pdf) / 2000, 1), 0.5)
+
+
+class DataFramePlotTests(DataFramePlotTestsMixin, ReusedSQLTestCase):
+ pass
+
+
+if __name__ == "__main__":
+ from pyspark.sql.tests.plot.test_frame_plot import * # noqa: F401
+
+ try:
+ import xmlrunner
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
new file mode 100644
index 0000000000000..72a3ed267d192
--- /dev/null
+++ b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
@@ -0,0 +1,64 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+import pyspark.sql.plot # noqa: F401
+from pyspark.testing.sqlutils import ReusedSQLTestCase, have_plotly, plotly_requirement_message
+
+
+@unittest.skipIf(not have_plotly, plotly_requirement_message)
+class DataFramePlotPlotlyTestsMixin:
+ @property
+ def sdf(self):
+ data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)]
+ columns = ["category", "int_val", "float_val"]
+ return self.spark.createDataFrame(data, columns)
+
+ def _check_fig_data(self, fig_data, expected_x, expected_y, expected_name=""):
+ self.assertEqual(fig_data["mode"], "lines")
+ self.assertEqual(fig_data["type"], "scatter")
+ self.assertEqual(fig_data["xaxis"], "x")
+ self.assertEqual(list(fig_data["x"]), expected_x)
+ self.assertEqual(fig_data["yaxis"], "y")
+ self.assertEqual(list(fig_data["y"]), expected_y)
+ self.assertEqual(fig_data["name"], expected_name)
+
+ def test_line_plot(self):
+ # single column as vertical axis
+ fig = self.sdf.plot(kind="line", x="category", y="int_val")
+ self._check_fig_data(fig["data"][0], ["A", "B", "C"], [10, 30, 20])
+
+ # multiple columns as vertical axis
+ fig = self.sdf.plot.line(x="category", y=["int_val", "float_val"])
+ self._check_fig_data(fig["data"][0], ["A", "B", "C"], [10, 30, 20], "int_val")
+ self._check_fig_data(fig["data"][1], ["A", "B", "C"], [1.5, 2.5, 3.5], "float_val")
+
+
+class DataFramePlotPlotlyTests(DataFramePlotPlotlyTestsMixin, ReusedSQLTestCase):
+ pass
+
+
+if __name__ == "__main__":
+ from pyspark.sql.tests.plot.test_frame_plot_plotly import * # noqa: F401
+
+ try:
+ import xmlrunner
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/streaming/test_streaming_foreach.py b/python/pyspark/sql/tests/streaming/test_streaming_foreach.py
index 5041fefff1909..b29338e7f59e7 100644
--- a/python/pyspark/sql/tests/streaming/test_streaming_foreach.py
+++ b/python/pyspark/sql/tests/streaming/test_streaming_foreach.py
@@ -220,9 +220,10 @@ def close(self, error):
try:
tester.run_streaming_query_on_writer(ForeachWriter(), 1)
self.fail("bad writer did not fail the query") # this is not expected
- except StreamingQueryException:
- # TODO: Verify whether original error message is inside the exception
- pass
+ except StreamingQueryException as e:
+ err_msg = str(e)
+ self.assertTrue("test error" in err_msg)
+ self.assertTrue("FOREACH_USER_FUNCTION_ERROR" in err_msg)
self.assertEqual(len(tester.process_events()), 0) # no row was processed
close_events = tester.close_events()
diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py
index 2bd66baaa2bfe..1972dd2804d98 100644
--- a/python/pyspark/sql/tests/test_column.py
+++ b/python/pyspark/sql/tests/test_column.py
@@ -18,11 +18,14 @@
from enum import Enum
from itertools import chain
+import datetime
+import unittest
+
from pyspark.sql import Column, Row
from pyspark.sql import functions as sf
from pyspark.sql.types import StructType, StructField, IntegerType, LongType
from pyspark.errors import AnalysisException, PySparkTypeError, PySparkValueError
-from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, pandas_requirement_message
class ColumnTestsMixin:
@@ -280,6 +283,33 @@ def test_expr_str_representation(self):
when_cond = sf.when(expression, sf.lit(None))
self.assertEqual(str(when_cond), "Column<'CASE WHEN foo THEN NULL END'>")
+ def test_lit_time_representation(self):
+ dt = datetime.date(2021, 3, 4)
+ self.assertEqual(str(sf.lit(dt)), "Column<'2021-03-04'>")
+
+ ts = datetime.datetime(2021, 3, 4, 12, 34, 56, 1234)
+ self.assertEqual(str(sf.lit(ts)), "Column<'2021-03-04 12:34:56.001234'>")
+
+ @unittest.skipIf(not have_pandas, pandas_requirement_message)
+ def test_lit_delta_representation(self):
+ import pandas as pd
+
+ for delta in [
+ datetime.timedelta(days=1),
+ datetime.timedelta(hours=2),
+ datetime.timedelta(minutes=3),
+ datetime.timedelta(seconds=4),
+ datetime.timedelta(microseconds=5),
+ datetime.timedelta(days=2, hours=21, microseconds=908),
+ datetime.timedelta(days=1, minutes=-3, microseconds=-1001),
+ datetime.timedelta(days=1, hours=2, minutes=3, seconds=4, microseconds=5),
+ ]:
+ # Column<'PT69H0.000908S'> or Column<'P2DT21H0M0.000908S'>
+ s = str(sf.lit(delta))
+
+ # Parse the ISO string representation and compare
+ self.assertTrue(pd.Timedelta(s[8:-2]).to_pytimedelta() == delta)
+
def test_enum_literals(self):
class IntEnum(Enum):
X = 1
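A small interactive sketch of the literal rendering checked above; the printed Column<'...'> strings are taken from the assertions, and an active SparkSession is assumed since lit() builds the literal through the session.

```python
import datetime

from pyspark.sql import SparkSession
from pyspark.sql import functions as sf

spark = SparkSession.builder.getOrCreate()

print(sf.lit(datetime.date(2021, 3, 4)))
# Column<'2021-03-04'>

print(sf.lit(datetime.datetime(2021, 3, 4, 12, 34, 56, 1234)))
# Column<'2021-03-04 12:34:56.001234'>

# Timedeltas render as ISO-8601 durations; the exact form may vary,
# e.g. Column<'PT69H0.000908S'> or Column<'P2DT21H0M0.000908S'>.
print(sf.lit(datetime.timedelta(days=2, hours=21, microseconds=908)))
```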
diff --git a/python/pyspark/sql/tests/test_creation.py b/python/pyspark/sql/tests/test_creation.py
index dfe66cdd3edf0..c6917aa234b41 100644
--- a/python/pyspark/sql/tests/test_creation.py
+++ b/python/pyspark/sql/tests/test_creation.py
@@ -15,7 +15,6 @@
# limitations under the License.
#
-import platform
from decimal import Decimal
import os
import time
@@ -111,11 +110,7 @@ def test_create_dataframe_from_pandas_with_dst(self):
os.environ["TZ"] = orig_env_tz
time.tzset()
- # TODO(SPARK-43354): Re-enable test_create_dataframe_from_pandas_with_day_time_interval
- @unittest.skipIf(
- "pypy" in platform.python_implementation().lower() or not have_pandas,
- "Fails in PyPy Python 3.8, should enable.",
- )
+ @unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore
def test_create_dataframe_from_pandas_with_day_time_interval(self):
# SPARK-37277: Test DayTimeIntervalType in createDataFrame without Arrow.
import pandas as pd
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index a214b874f5ec0..8ec0839ec1fe4 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -39,6 +39,7 @@
PySparkTypeError,
PySparkValueError,
)
+from pyspark.testing import assertDataFrameEqual
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
have_pyarrow,
@@ -955,6 +956,74 @@ def test_checkpoint_dataframe(self):
self.spark.range(1).localCheckpoint().explain()
self.assertIn("ExistingRDD", buf.getvalue())
+ def test_transpose(self):
+ df = self.spark.createDataFrame([{"a": "x", "b": "y", "c": "z"}])
+
+ # default index column
+ transposed_df = df.transpose()
+ expected_schema = StructType(
+ [StructField("key", StringType(), False), StructField("x", StringType(), True)]
+ )
+ expected_data = [Row(key="b", x="y"), Row(key="c", x="z")]
+ expected_df = self.spark.createDataFrame(expected_data, schema=expected_schema)
+ assertDataFrameEqual(transposed_df, expected_df, checkRowOrder=True)
+
+ # specified index column
+ transposed_df = df.transpose("c")
+ expected_schema = StructType(
+ [StructField("key", StringType(), False), StructField("z", StringType(), True)]
+ )
+ expected_data = [Row(key="a", z="x"), Row(key="b", z="y")]
+ expected_df = self.spark.createDataFrame(expected_data, schema=expected_schema)
+ assertDataFrameEqual(transposed_df, expected_df, checkRowOrder=True)
+
+ # enforce transpose max values
+ with self.sql_conf({"spark.sql.transposeMaxValues": 0}):
+ with self.assertRaises(AnalysisException) as pe:
+ df.transpose().collect()
+ self.check_error(
+ exception=pe.exception,
+ errorClass="TRANSPOSE_EXCEED_ROW_LIMIT",
+ messageParameters={"maxValues": "0", "config": "spark.sql.transposeMaxValues"},
+ )
+
+ # enforce ascending order based on index column values for transposed columns
+ df = self.spark.createDataFrame([{"a": "z"}, {"a": "y"}, {"a": "x"}])
+ transposed_df = df.transpose()
+ expected_schema = StructType(
+ [
+ StructField("key", StringType(), False),
+ StructField("x", StringType(), True),
+ StructField("y", StringType(), True),
+ StructField("z", StringType(), True),
+ ]
+ ) # z, y, x -> x, y, z
+ expected_df = self.spark.createDataFrame([], schema=expected_schema)
+ assertDataFrameEqual(transposed_df, expected_df, checkRowOrder=True)
+
+ # enforce AtomicType Attribute for index column values
+ df = self.spark.createDataFrame([{"a": ["x", "x"], "b": "y", "c": "z"}])
+ with self.assertRaises(AnalysisException) as pe:
+ df.transpose().collect()
+ self.check_error(
+ exception=pe.exception,
+ errorClass="TRANSPOSE_INVALID_INDEX_COLUMN",
+ messageParameters={
+ "reason": "Index column must be of atomic type, "
+ "but found: ArrayType(StringType,true)"
+ },
+ )
+
+ # enforce least common type for non-index columns
+ df = self.spark.createDataFrame([{"a": "x", "b": "y", "c": 1}])
+ with self.assertRaises(AnalysisException) as pe:
+ df.transpose().collect()
+ self.check_error(
+ exception=pe.exception,
+ errorClass="TRANSPOSE_NO_LEAST_COMMON_TYPE",
+ messageParameters={"dt1": "STRING", "dt2": "BIGINT"},
+ )
+
class DataFrameTests(DataFrameTestsMixin, ReusedSQLTestCase):
def test_query_execution_unsupported_in_classic(self):
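As a quick reference, a hedged sketch of the DataFrame.transpose API covered by the new test; the behavior shown (first column as the default index, a non-nullable "key" column holding the former column names, the spark.sql.transposeMaxValues cap) is taken from the assertions above, and the feature is assumed to be available in the build.

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([{"a": "x", "b": "y", "c": "z"}])

# Default: the first column ("a") is the index; its values become column names,
# and the remaining column names land in a non-nullable "key" column.
df.transpose().show()
# +---+---+
# |key|  x|
# +---+---+
# |  b|  y|
# |  c|  z|
# +---+---+

# An explicit index column can be passed instead; spark.sql.transposeMaxValues
# caps how many distinct index values may become columns.
df.transpose("c").show()
```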
diff --git a/python/pyspark/sql/tests/test_dataframe_query_context.py b/python/pyspark/sql/tests/test_dataframe_query_context.py
index 3f31f1d62d73d..bf0cc021ca771 100644
--- a/python/pyspark/sql/tests/test_dataframe_query_context.py
+++ b/python/pyspark/sql/tests/test_dataframe_query_context.py
@@ -54,7 +54,6 @@ def test_dataframe_query_context(self):
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
- "ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
fragment="__add__",
@@ -70,7 +69,6 @@ def test_dataframe_query_context(self):
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
- "ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
fragment="__sub__",
@@ -86,7 +84,6 @@ def test_dataframe_query_context(self):
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
- "ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
fragment="__mul__",
@@ -102,7 +99,6 @@ def test_dataframe_query_context(self):
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
- "ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
fragment="__mod__",
@@ -118,7 +114,6 @@ def test_dataframe_query_context(self):
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
- "ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
fragment="__eq__",
@@ -134,7 +129,6 @@ def test_dataframe_query_context(self):
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
- "ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
fragment="__lt__",
@@ -150,7 +144,6 @@ def test_dataframe_query_context(self):
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
- "ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
fragment="__le__",
@@ -166,7 +159,6 @@ def test_dataframe_query_context(self):
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
- "ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
fragment="__ge__",
@@ -182,7 +174,6 @@ def test_dataframe_query_context(self):
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
- "ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
fragment="__gt__",
@@ -198,7 +189,6 @@ def test_dataframe_query_context(self):
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
- "ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
fragment="eqNullSafe",
@@ -214,7 +204,6 @@ def test_dataframe_query_context(self):
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
- "ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
fragment="bitwiseOR",
@@ -230,7 +219,6 @@ def test_dataframe_query_context(self):
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
- "ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
fragment="bitwiseAND",
@@ -246,7 +234,6 @@ def test_dataframe_query_context(self):
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
- "ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
fragment="bitwiseXOR",
@@ -279,7 +266,6 @@ def test_dataframe_query_context(self):
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
- "ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
fragment="__add__",
@@ -299,7 +285,6 @@ def test_dataframe_query_context(self):
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
- "ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
fragment="__sub__",
@@ -317,7 +302,6 @@ def test_dataframe_query_context(self):
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
- "ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
fragment="__mul__",
@@ -344,7 +328,6 @@ def test_dataframe_query_context(self):
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
- "ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
fragment="__add__",
@@ -360,7 +343,6 @@ def test_dataframe_query_context(self):
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
- "ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
fragment="__sub__",
@@ -376,7 +358,6 @@ def test_dataframe_query_context(self):
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
- "ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
fragment="__mul__",
@@ -407,7 +388,6 @@ def test_dataframe_query_context(self):
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
- "ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
fragment="__add__",
@@ -425,7 +405,6 @@ def test_dataframe_query_context(self):
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
- "ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
fragment="__sub__",
@@ -443,7 +422,6 @@ def test_dataframe_query_context(self):
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
- "ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
fragment="__mul__",
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index f7f2485a43e16..a0ab9bc9c7d40 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -1326,8 +1326,8 @@ def check(resultDf, expected):
self.assertEqual([r[0] for r in resultDf.collect()], expected)
check(df.select(F.is_variant_null(v)), [False, False])
- check(df.select(F.schema_of_variant(v)), ["STRUCT<a: BIGINT>", "STRUCT<b: BIGINT>"])
- check(df.select(F.schema_of_variant_agg(v)), ["STRUCT<a: BIGINT, b: BIGINT>"])
+ check(df.select(F.schema_of_variant(v)), ["OBJECT<a: BIGINT>", "OBJECT<b: BIGINT>"])
+ check(df.select(F.schema_of_variant_agg(v)), ["OBJECT<a: BIGINT, b: BIGINT>"])
check(df.select(F.variant_get(v, "$.a", "int")), [1, None])
check(df.select(F.variant_get(v, "$.b", "int")), [None, 2])
@@ -1365,6 +1365,13 @@ def test_try_parse_json(self):
self.assertEqual("""{"a":1}""", actual[0]["var"])
self.assertEqual(None, actual[1]["var"])
+ def test_to_variant_object(self):
+ df = self.spark.createDataFrame([(1, {"a": 1})], "i int, v struct<a int>")
+ actual = df.select(
+ F.to_json(F.to_variant_object(df.v)).alias("var"),
+ ).collect()
+ self.assertEqual("""{"a":1}""", actual[0]["var"])
+
def test_schema_of_csv(self):
with self.assertRaises(PySparkTypeError) as pe:
F.schema_of_csv(1)
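For reference, a compact sketch of the new to_variant_object function under test; it assumes the function is exposed as pyspark.sql.functions.to_variant_object in this build, and the schema string mirrors the test's.

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, {"a": 1})], "i int, v struct<a int>")

# Wrap the struct into a VARIANT value, then render it back to JSON for inspection.
df.select(F.to_json(F.to_variant_object(df.v)).alias("var")).show()
# +-------+
# |    var|
# +-------+
# |{"a":1}|
# +-------+
```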
diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py
index 8431e9b3e35d4..140c7680b181b 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -374,6 +374,68 @@ def test_case_insensitive_dict(self):
self.assertEqual(d2["BaR"], 3)
self.assertEqual(d2["baz"], 3)
+ def test_arrow_batch_data_source(self):
+ import pyarrow as pa
+
+ class ArrowBatchDataSource(DataSource):
+ """
+ A data source for testing Arrow Batch Serialization
+ """
+
+ @classmethod
+ def name(cls):
+ return "arrowbatch"
+
+ def schema(self):
+ return "key int, value string"
+
+ def reader(self, schema: str):
+ return ArrowBatchDataSourceReader(schema, self.options)
+
+ class ArrowBatchDataSourceReader(DataSourceReader):
+ def __init__(self, schema, options):
+ self.schema: str = schema
+ self.options = options
+
+ def read(self, partition):
+ # Create Arrow Record Batch
+ keys = pa.array([1, 2, 3, 4, 5], type=pa.int32())
+ values = pa.array(["one", "two", "three", "four", "five"], type=pa.string())
+ schema = pa.schema([("key", pa.int32()), ("value", pa.string())])
+ record_batch = pa.RecordBatch.from_arrays([keys, values], schema=schema)
+ yield record_batch
+
+ def partitions(self):
+ # hardcoded number of partitions
+ num_part = 1
+ return [InputPartition(i) for i in range(num_part)]
+
+ self.spark.dataSource.register(ArrowBatchDataSource)
+ df = self.spark.read.format("arrowbatch").load()
+ expected_data = [
+ Row(key=1, value="one"),
+ Row(key=2, value="two"),
+ Row(key=3, value="three"),
+ Row(key=4, value="four"),
+ Row(key=5, value="five"),
+ ]
+ assertDataFrameEqual(df, expected_data)
+
+ with self.assertRaisesRegex(
+ PythonException,
+ "PySparkRuntimeError: \\[DATA_SOURCE_RETURN_SCHEMA_MISMATCH\\] Return schema"
+ " mismatch in the result from 'read' method\\. Expected: 1 columns, Found: 2 columns",
+ ):
+ self.spark.read.format("arrowbatch").schema("dummy int").load().show()
+
+ with self.assertRaisesRegex(
+ PythonException,
+ "PySparkRuntimeError: \\[DATA_SOURCE_RETURN_SCHEMA_MISMATCH\\] Return schema mismatch"
+ " in the result from 'read' method\\. Expected: \\['key', 'dummy'\\] columns, Found:"
+ " \\['key', 'value'\\] columns",
+ ):
+ self.spark.read.format("arrowbatch").schema("key int, dummy string").load().show()
+
def test_data_source_type_mismatch(self):
class TestDataSource(DataSource):
@classmethod
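Condensed, the contract the test above exercises: a Python DataSourceReader.read() may now yield pyarrow.RecordBatch objects (whose column count and names must match the declared schema) instead of tuples or Rows. A stripped-down sketch with illustrative names:

```python
import pyarrow as pa
from pyspark.sql.datasource import DataSource, DataSourceReader


class MiniArrowReader(DataSourceReader):
    def read(self, partition):
        # Yielding a RecordBatch skips the per-row tuple conversion path.
        yield pa.RecordBatch.from_arrays(
            [pa.array([1, 2], type=pa.int32()), pa.array(["one", "two"])],
            names=["key", "value"],
        )


class MiniArrowSource(DataSource):
    @classmethod
    def name(cls):
        return "miniarrow"

    def schema(self):
        return "key int, value string"

    def reader(self, schema):
        return MiniArrowReader()


# Usage (inside an active session):
#   spark.dataSource.register(MiniArrowSource)
#   spark.read.format("miniarrow").load().show()
```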
diff --git a/python/pyspark/sql/tests/test_python_streaming_datasource.py b/python/pyspark/sql/tests/test_python_streaming_datasource.py
index 183b0ad80d9d4..fa14b37b57e62 100644
--- a/python/pyspark/sql/tests/test_python_streaming_datasource.py
+++ b/python/pyspark/sql/tests/test_python_streaming_datasource.py
@@ -152,6 +152,66 @@ def check_batch(df, batch_id):
q.awaitTermination()
self.assertIsNone(q.exception(), "No exception has to be propagated.")
+ def test_stream_reader_pyarrow(self):
+ import pyarrow as pa
+
+ class TestStreamReader(DataSourceStreamReader):
+ def initialOffset(self):
+ return {"offset": 0}
+
+ def latestOffset(self):
+ return {"offset": 2}
+
+ def partitions(self, start, end):
+ # hardcoded number of partitions
+ num_part = 1
+ return [InputPartition(i) for i in range(num_part)]
+
+ def read(self, partition):
+ keys = pa.array([1, 2, 3, 4, 5], type=pa.int32())
+ values = pa.array(["one", "two", "three", "four", "five"], type=pa.string())
+ schema = pa.schema([("key", pa.int32()), ("value", pa.string())])
+ record_batch = pa.RecordBatch.from_arrays([keys, values], schema=schema)
+ yield record_batch
+
+ class TestDataSourcePyarrow(DataSource):
+ @classmethod
+ def name(cls):
+ return "testdatasourcepyarrow"
+
+ def schema(self):
+ return "key int, value string"
+
+ def streamReader(self, schema):
+ return TestStreamReader()
+
+ self.spark.dataSource.register(TestDataSourcePyarrow)
+ df = self.spark.readStream.format("testdatasourcepyarrow").load()
+
+ output_dir = tempfile.TemporaryDirectory(prefix="test_data_stream_write_output")
+ checkpoint_dir = tempfile.TemporaryDirectory(prefix="test_data_stream_write_checkpoint")
+
+ q = (
+ df.writeStream.format("json")
+ .option("checkpointLocation", checkpoint_dir.name)
+ .start(output_dir.name)
+ )
+ while not q.recentProgress:
+ time.sleep(0.2)
+ q.stop()
+ q.awaitTermination()
+
+ expected_data = [
+ Row(key=1, value="one"),
+ Row(key=2, value="two"),
+ Row(key=3, value="three"),
+ Row(key=4, value="four"),
+ Row(key=5, value="five"),
+ ]
+ df = self.spark.read.json(output_dir.name)
+
+ assertDataFrameEqual(df, expected_data)
+
def test_simple_stream_reader(self):
class SimpleStreamReader(SimpleDataSourceStreamReader):
def initialOffset(self):
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index a97734ee0fcef..5d9ec92cbc830 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -41,6 +41,7 @@
PythonException,
UnknownException,
SparkUpgradeException,
+ PySparkImportError,
PySparkNotImplementedError,
PySparkRuntimeError,
)
@@ -115,6 +116,22 @@ def require_test_compiled() -> None:
)
+def require_minimum_plotly_version() -> None:
+ """Raise ImportError if plotly is not installed"""
+ minimum_plotly_version = "4.8"
+
+ try:
+ import plotly # noqa: F401
+ except ImportError as error:
+ raise PySparkImportError(
+ errorClass="PACKAGE_NOT_INSTALLED",
+ messageParameters={
+ "package_name": "plotly",
+ "minimum_version": str(minimum_plotly_version),
+ },
+ ) from error
+
+
class ForeachBatchFunction:
"""
This is the Python implementation of Java interface 'ForeachBatchFunction'. This wraps
@@ -336,7 +353,7 @@ def wrapped(*args: Any, **kwargs: Any) -> Any:
def dispatch_df_method(f: FuncT) -> FuncT:
"""
- For the usecases of direct DataFrame.union(df, ...), it checks if self
+ For the use cases of direct DataFrame.method(df, ...), it checks if self
is a Connect DataFrame or Classic DataFrame, and dispatches.
"""
@@ -363,8 +380,8 @@ def wrapped(*args: Any, **kwargs: Any) -> Any:
def dispatch_col_method(f: FuncT) -> FuncT:
"""
- For the usecases of direct Column.method(col, ...), it checks if self
- is a Connect DataFrame or Classic DataFrame, and dispatches.
+ For the use cases of direct Column.method(col, ...), it checks if self
+ is a Connect Column or Classic Column, and dispatches.
"""
@functools.wraps(f)
@@ -390,8 +407,9 @@ def wrapped(*args: Any, **kwargs: Any) -> Any:
def dispatch_window_method(f: FuncT) -> FuncT:
"""
- For the usecases of direct Window.method(col, ...), it checks if self
- is a Connect Window or Classic Window, and dispatches.
+ For use cases of direct Window.method(col, ...), this function dispatches
+ the call to either ConnectWindow or ClassicWindow based on the execution
+ environment.
"""
@functools.wraps(f)
@@ -405,11 +423,6 @@ def wrapped(*args: Any, **kwargs: Any) -> Any:
return getattr(ClassicWindow, f.__name__)(*args, **kwargs)
- raise PySparkNotImplementedError(
- errorClass="NOT_IMPLEMENTED",
- messageParameters={"feature": f"Window.{f.__name__}"},
- )
-
return cast(FuncT, wrapped)
diff --git a/python/pyspark/sql/worker/plan_data_source_read.py b/python/pyspark/sql/worker/plan_data_source_read.py
index 16b98ac0ed1e4..2af25fb52f150 100644
--- a/python/pyspark/sql/worker/plan_data_source_read.py
+++ b/python/pyspark/sql/worker/plan_data_source_read.py
@@ -19,7 +19,7 @@
import sys
import functools
import pyarrow as pa
-from itertools import islice
+from itertools import islice, chain
from typing import IO, List, Iterator, Iterable, Tuple, Union
from pyspark.accumulators import _accumulatorRegistry
@@ -59,21 +59,18 @@
def records_to_arrow_batches(
- output_iter: Iterator[Tuple],
+ output_iter: Union[Iterator[Tuple], Iterator[pa.RecordBatch]],
max_arrow_batch_size: int,
return_type: StructType,
data_source: DataSource,
) -> Iterable[pa.RecordBatch]:
"""
- Convert an iterator of Python tuples to an iterator of pyarrow record batches.
-
- For each python tuple, check the types of each field and append it to the records batch.
-
+ First check whether the iterator yields PyArrow's `pyarrow.RecordBatch`; if so, yield
+ the batches directly. Otherwise, convert an iterator of Python tuples to an iterator
+ of pyarrow record batches. For each Python tuple, check the types of each field
+ and append it to the record batch.
"""
- def batched(iterator: Iterator, n: int) -> Iterator:
- return iter(functools.partial(lambda it: list(islice(it, n)), iterator), [])
-
pa_schema = to_arrow_schema(return_type)
column_names = return_type.fieldNames()
column_converters = [
@@ -83,6 +80,45 @@ def batched(iterator: Iterator, n: int) -> Iterator:
num_cols = len(column_names)
col_mapping = {name: i for i, name in enumerate(column_names)}
col_name_set = set(column_names)
+
+ try:
+ first_element = next(output_iter)
+ except StopIteration:
+ return
+
+ # If the first element is a pa.RecordBatch, yield all elements directly and return
+ if isinstance(first_element, pa.RecordBatch):
+ # Validate the schema, check the RecordBatch column count
+ num_columns = first_element.num_columns
+ if num_columns != num_cols:
+ raise PySparkRuntimeError(
+ errorClass="DATA_SOURCE_RETURN_SCHEMA_MISMATCH",
+ messageParameters={
+ "expected": str(num_cols),
+ "actual": str(num_columns),
+ },
+ )
+ for name in column_names:
+ if name not in first_element.schema.names:
+ raise PySparkRuntimeError(
+ errorClass="DATA_SOURCE_RETURN_SCHEMA_MISMATCH",
+ messageParameters={
+ "expected": str(column_names),
+ "actual": str(first_element.schema.names),
+ },
+ )
+
+ yield first_element
+ for element in output_iter:
+ yield element
+ return
+
+ # Put the first element back to the iterator
+ output_iter = chain([first_element], output_iter)
+
+ def batched(iterator: Iterator, n: int) -> Iterator:
+ return iter(functools.partial(lambda it: list(islice(it, n)), iterator), [])
+
for batch in batched(output_iter, max_arrow_batch_size):
pylist: List[List] = [[] for _ in range(num_cols)]
for result in batch:
@@ -103,7 +139,8 @@ def batched(iterator: Iterator, n: int) -> Iterator:
messageParameters={
"type": type(result).__name__,
"name": data_source.name(),
- "supported_types": "tuple, list, `pyspark.sql.types.Row`",
+ "supported_types": "tuple, list, `pyspark.sql.types.Row`,"
+ " `pyarrow.RecordBatch`",
},
)
@@ -145,9 +182,10 @@ def main(infile: IO, outfile: IO) -> None:
This process then creates a `DataSourceReader` instance by calling the `reader` method
on the `DataSource` instance. Then it calls the `partitions()` method of the reader and
- constructs a Python UDTF using the `read()` method of the reader.
+ constructs a PyArrow `RecordBatch` with the data using the `read()` method of the reader.
- The partition values and the UDTF are then serialized and sent back to the JVM via the socket.
+ The partition values and the Arrow Batch are then serialized and sent back to the JVM
+ via the socket.
"""
try:
check_python_version(infile)
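The dispatch added above uses a small, general iterator pattern: pull the first element to inspect its type, then stitch it back with itertools.chain so the slow path still sees the full stream. A standalone illustration of just that pattern (generic names, not the worker code itself):

```python
from itertools import chain
from typing import Iterable, Iterator, Union


def to_strings(items: Iterator[Union[int, str]]) -> Iterable[str]:
    try:
        first = next(items)  # peek at the first element to pick a code path
    except StopIteration:
        return  # empty input: nothing to yield

    if isinstance(first, str):
        # Fast path: the data is already in the desired form, pass it through.
        yield first
        yield from items
        return

    # Slow path: put the peeked element back and convert every element.
    for item in chain([first], items):
        yield str(item)


print(list(to_strings(iter([1, 2, 3]))))   # ['1', '2', '3']
print(list(to_strings(iter(["a", "b"]))))  # ['a', 'b']
```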
diff --git a/python/pyspark/testing/sqlutils.py b/python/pyspark/testing/sqlutils.py
index 9f07c44c084cf..00ad40e68bd7c 100644
--- a/python/pyspark/testing/sqlutils.py
+++ b/python/pyspark/testing/sqlutils.py
@@ -48,6 +48,13 @@
except Exception as e:
test_not_compiled_message = str(e)
+plotly_requirement_message = None
+try:
+ import plotly
+except ImportError as e:
+ plotly_requirement_message = str(e)
+have_plotly = plotly_requirement_message is None
+
from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType, Row
from pyspark.testing.utils import ReusedPySparkTestCase, PySparkErrorTestUtils
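The new have_plotly / plotly_requirement_message pair mirrors the existing have_pandas helpers, so plotting tests can skip themselves when plotly is absent; a hedged sketch of the intended usage (the test class and data are illustrative):

```python
import unittest

from pyspark.testing.sqlutils import (
    ReusedSQLTestCase,
    have_plotly,
    plotly_requirement_message,
)


@unittest.skipIf(not have_plotly, plotly_requirement_message)
class ExamplePlotlyTests(ReusedSQLTestCase):
    def test_line_plot_returns_scatter_trace(self):
        sdf = self.spark.createDataFrame(
            [("A", 10), ("B", 30), ("C", 20)], ["category", "int_val"]
        )
        fig = sdf.plot.line(x="category", y="int_val")
        self.assertEqual(fig["data"][0]["type"], "scatter")


if __name__ == "__main__":
    unittest.main()
```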
diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py
index f33f2111c5a1e..5488d11d868f5 100644
--- a/python/pyspark/testing/utils.py
+++ b/python/pyspark/testing/utils.py
@@ -185,7 +185,7 @@ def setUpClass(cls):
def tearDownClass(cls):
cls.sc.stop()
- def test_assert_vanilla_mode(self):
+ def test_assert_classic_mode(self):
from pyspark.sql import is_remote
self.assertFalse(is_remote())
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index 205e3d957a415..cca44435efe67 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -262,10 +262,6 @@ def try_simplify_traceback(tb: TracebackType) -> Optional[TracebackType]:
if "pypy" in platform.python_implementation().lower():
# Traceback modification is not supported with PyPy in PySpark.
return None
- if sys.version_info[:2] < (3, 7):
- # Traceback creation is not supported Python < 3.7.
- # See https://bugs.python.org/issue30579.
- return None
import pyspark
@@ -791,7 +787,7 @@ def is_remote_only() -> bool:
if __name__ == "__main__":
- if "pypy" not in platform.python_implementation().lower() and sys.version_info[:2] >= (3, 7):
+ if "pypy" not in platform.python_implementation().lower() and sys.version_info[:2] >= (3, 9):
import doctest
import pyspark.util
from pyspark.core.context import SparkContext
diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml
index fa0fd454ccc44..211c6c93b9674 100644
--- a/resource-managers/kubernetes/core/pom.xml
+++ b/resource-managers/kubernetes/core/pom.xml
@@ -29,15 +29,11 @@
Spark Project Kuberneteskubernetes
- **/*Volcano*.scalavolcano
-
-
- io.fabric8
@@ -50,6 +46,40 @@
${kubernetes-client.version}
+
+
+
+ org.codehaus.mojo
+ build-helper-maven-plugin
+
+
+ add-volcano-source
+ generate-sources
+
+ add-source
+
+
+
+
+
+
+
+
+ add-volcano-test-sources
+ generate-test-sources
+
+ add-test-source
+
+
+
+
+
+
+
+
+
+
+
@@ -151,19 +181,6 @@
-
-
-
- net.alchim31.maven
- scala-maven-plugin
-
-
- ${volcano.exclude}
-
-
-
-
- target/scala-${scala.binary.version}/classestarget/scala-${scala.binary.version}/test-classes
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala
index 393ffc5674011..3a4d68c19014d 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala
@@ -776,7 +776,7 @@ private[spark] object Config extends Logging {
val KUBERNETES_VOLUMES_OPTIONS_MEDIUM_KEY = "options.medium"
val KUBERNETES_VOLUMES_OPTIONS_SIZE_LIMIT_KEY = "options.sizeLimit"
val KUBERNETES_VOLUMES_OPTIONS_SERVER_KEY = "options.server"
-
+ val KUBERNETES_VOLUMES_LABEL_KEY = "label."
val KUBERNETES_DRIVER_ENV_PREFIX = "spark.kubernetes.driverEnv."
val KUBERNETES_DNS_SUBDOMAIN_NAME_MAX_LENGTH = 253
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala
index 3f7355de18911..9dfd40a773eb1 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala
@@ -24,7 +24,8 @@ private[spark] case class KubernetesHostPathVolumeConf(hostPath: String)
private[spark] case class KubernetesPVCVolumeConf(
claimName: String,
storageClass: Option[String] = None,
- size: Option[String] = None)
+ size: Option[String] = None,
+ labels: Option[Map[String, String]] = None)
extends KubernetesVolumeSpecificConf
private[spark] case class KubernetesEmptyDirVolumeConf(
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala
index ee2108e8234d3..6463512c0114b 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala
@@ -45,13 +45,21 @@ object KubernetesVolumeUtils {
val pathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_PATH_KEY"
val readOnlyKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_READONLY_KEY"
val subPathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_SUBPATH_KEY"
+ val labelKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_LABEL_KEY"
+
+ val volumeLabelsMap = properties
+ .filter(_._1.startsWith(labelKey))
+ .map {
+ case (k, v) => k.replaceAll(labelKey, "") -> v
+ }
KubernetesVolumeSpec(
volumeName = volumeName,
mountPath = properties(pathKey),
mountSubPath = properties.getOrElse(subPathKey, ""),
mountReadOnly = properties.get(readOnlyKey).exists(_.toBoolean),
- volumeConf = parseVolumeSpecificConf(properties, volumeType, volumeName))
+ volumeConf = parseVolumeSpecificConf(properties,
+ volumeType, volumeName, Option(volumeLabelsMap)))
}.toSeq
}
@@ -74,7 +82,8 @@ object KubernetesVolumeUtils {
private def parseVolumeSpecificConf(
options: Map[String, String],
volumeType: String,
- volumeName: String): KubernetesVolumeSpecificConf = {
+ volumeName: String,
+ labels: Option[Map[String, String]]): KubernetesVolumeSpecificConf = {
volumeType match {
case KUBERNETES_VOLUMES_HOSTPATH_TYPE =>
val pathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_PATH_KEY"
@@ -91,7 +100,8 @@ object KubernetesVolumeUtils {
KubernetesPVCVolumeConf(
options(claimNameKey),
options.get(storageClassKey),
- options.get(sizeLimitKey))
+ options.get(sizeLimitKey),
+ labels)
case KUBERNETES_VOLUMES_EMPTYDIR_TYPE =>
val mediumKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_MEDIUM_KEY"
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala
index 72cc012a6bdd0..5cc61c746b0e0 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala
@@ -74,7 +74,7 @@ private[spark] class MountVolumesFeatureStep(conf: KubernetesConf)
new VolumeBuilder()
.withHostPath(new HostPathVolumeSource(hostPath, ""))
- case KubernetesPVCVolumeConf(claimNameTemplate, storageClass, size) =>
+ case KubernetesPVCVolumeConf(claimNameTemplate, storageClass, size, labels) =>
val claimName = conf match {
case c: KubernetesExecutorConf =>
claimNameTemplate
@@ -86,12 +86,17 @@ private[spark] class MountVolumesFeatureStep(conf: KubernetesConf)
.replaceAll(PVC_ON_DEMAND, s"${conf.resourceNamePrefix}-driver$PVC_POSTFIX-$i")
}
if (storageClass.isDefined && size.isDefined) {
+ val defaultVolumeLabels = Map(SPARK_APP_ID_LABEL -> conf.appId)
+ val volumeLabels = labels match {
+ case Some(customLabelsMap) => (customLabelsMap ++ defaultVolumeLabels).asJava
+ case None => defaultVolumeLabels.asJava
+ }
additionalResources.append(new PersistentVolumeClaimBuilder()
.withKind(PVC)
.withApiVersion("v1")
.withNewMetadata()
.withName(claimName)
- .addToLabels(SPARK_APP_ID_LABEL, conf.appId)
+ .addToLabels(volumeLabels)
.endMetadata()
.withNewSpec()
.withStorageClassName(storageClass.get)
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesTestConf.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesTestConf.scala
index b70b9348d23b4..7e0a65bcdda90 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesTestConf.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesTestConf.scala
@@ -117,12 +117,17 @@ object KubernetesTestConf {
(KUBERNETES_VOLUMES_HOSTPATH_TYPE,
Map(KUBERNETES_VOLUMES_OPTIONS_PATH_KEY -> path))
- case KubernetesPVCVolumeConf(claimName, storageClass, sizeLimit) =>
+ case KubernetesPVCVolumeConf(claimName, storageClass, sizeLimit, labels) =>
val sconf = storageClass
.map { s => (KUBERNETES_VOLUMES_OPTIONS_CLAIM_STORAGE_CLASS_KEY, s) }.toMap
val lconf = sizeLimit.map { l => (KUBERNETES_VOLUMES_OPTIONS_SIZE_LIMIT_KEY, l) }.toMap
+ val llabels = labels match {
+ case Some(value) => value.map { case(k, v) => s"label.$k" -> v }
+ case None => Map()
+ }
(KUBERNETES_VOLUMES_PVC_TYPE,
- Map(KUBERNETES_VOLUMES_OPTIONS_CLAIM_NAME_KEY -> claimName) ++ sconf ++ lconf)
+ Map(KUBERNETES_VOLUMES_OPTIONS_CLAIM_NAME_KEY -> claimName) ++
+ sconf ++ lconf ++ llabels)
case KubernetesEmptyDirVolumeConf(medium, sizeLimit) =>
val mconf = medium.map { m => (KUBERNETES_VOLUMES_OPTIONS_MEDIUM_KEY, m) }.toMap
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala
index fdc1aae0d4109..5c103739d3082 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala
@@ -56,7 +56,39 @@ class KubernetesVolumeUtilsSuite extends SparkFunSuite {
assert(volumeSpec.mountPath === "/path")
assert(volumeSpec.mountReadOnly)
assert(volumeSpec.volumeConf.asInstanceOf[KubernetesPVCVolumeConf] ===
- KubernetesPVCVolumeConf("claimName"))
+ KubernetesPVCVolumeConf("claimName", labels = Some(Map())))
+ }
+
+ test("SPARK-49598: Parses persistentVolumeClaim volumes correctly with labels") {
+ val sparkConf = new SparkConf(false)
+ sparkConf.set("test.persistentVolumeClaim.volumeName.mount.path", "/path")
+ sparkConf.set("test.persistentVolumeClaim.volumeName.mount.readOnly", "true")
+ sparkConf.set("test.persistentVolumeClaim.volumeName.options.claimName", "claimName")
+ sparkConf.set("test.persistentVolumeClaim.volumeName.label.env", "test")
+ sparkConf.set("test.persistentVolumeClaim.volumeName.label.foo", "bar")
+
+ val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head
+ assert(volumeSpec.volumeName === "volumeName")
+ assert(volumeSpec.mountPath === "/path")
+ assert(volumeSpec.mountReadOnly)
+ assert(volumeSpec.volumeConf.asInstanceOf[KubernetesPVCVolumeConf] ===
+ KubernetesPVCVolumeConf(claimName = "claimName",
+ labels = Some(Map("env" -> "test", "foo" -> "bar"))))
+ }
+
+ test("SPARK-49598: Parses persistentVolumeClaim volumes & puts " +
+ "labels as empty Map if not provided") {
+ val sparkConf = new SparkConf(false)
+ sparkConf.set("test.persistentVolumeClaim.volumeName.mount.path", "/path")
+ sparkConf.set("test.persistentVolumeClaim.volumeName.mount.readOnly", "true")
+ sparkConf.set("test.persistentVolumeClaim.volumeName.options.claimName", "claimName")
+
+ val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head
+ assert(volumeSpec.volumeName === "volumeName")
+ assert(volumeSpec.mountPath === "/path")
+ assert(volumeSpec.mountReadOnly)
+ assert(volumeSpec.volumeConf.asInstanceOf[KubernetesPVCVolumeConf] ===
+ KubernetesPVCVolumeConf(claimName = "claimName", labels = Some(Map())))
}
test("Parses emptyDir volumes correctly") {
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala
index 54796def95e53..6a68898c5f61c 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala
@@ -131,6 +131,79 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite {
assert(pvcClaim.getClaimName.endsWith("-driver-pvc-0"))
}
+ test("SPARK-49598: Create and mounts persistentVolumeClaims in driver with labels") {
+ val volumeConf = KubernetesVolumeSpec(
+ "testVolume",
+ "/tmp",
+ "",
+ true,
+ KubernetesPVCVolumeConf(claimName = MountVolumesFeatureStep.PVC_ON_DEMAND,
+ storageClass = Some("gp3"),
+ size = Some("1Mi"),
+ labels = Some(Map("foo" -> "bar", "env" -> "test")))
+ )
+
+ val kubernetesConf = KubernetesTestConf.createDriverConf(volumes = Seq(volumeConf))
+ val step = new MountVolumesFeatureStep(kubernetesConf)
+ val configuredPod = step.configurePod(SparkPod.initialPod())
+ assert(configuredPod.pod.getSpec.getVolumes.size() === 1)
+ val pvcClaim = configuredPod.pod.getSpec.getVolumes.get(0).getPersistentVolumeClaim
+ assert(pvcClaim.getClaimName.endsWith("-driver-pvc-0"))
+ }
+
+ test("SPARK-49598: Create and mounts persistentVolumeClaims in executors with labels") {
+ val volumeConf = KubernetesVolumeSpec(
+ "testVolume",
+ "/tmp",
+ "",
+ true,
+ KubernetesPVCVolumeConf(claimName = MountVolumesFeatureStep.PVC_ON_DEMAND,
+ storageClass = Some("gp3"),
+ size = Some("1Mi"),
+ labels = Some(Map("foo1" -> "bar1", "env" -> "exec-test")))
+ )
+
+ val executorConf = KubernetesTestConf.createExecutorConf(volumes = Seq(volumeConf))
+ val executorStep = new MountVolumesFeatureStep(executorConf)
+ val executorPod = executorStep.configurePod(SparkPod.initialPod())
+
+ assert(executorPod.pod.getSpec.getVolumes.size() === 1)
+ val executorPVC = executorPod.pod.getSpec.getVolumes.get(0).getPersistentVolumeClaim
+ assert(executorPVC.getClaimName.endsWith("-exec-1-pvc-0"))
+ }
+
+ test("SPARK-49598: Mount multiple volumes to executor with labels") {
+ val pvcVolumeConf1 = KubernetesVolumeSpec(
+ "checkpointVolume1",
+ "/checkpoints1",
+ "",
+ true,
+ KubernetesPVCVolumeConf(claimName = "pvcClaim1",
+ storageClass = Some("gp3"),
+ size = Some("1Mi"),
+ labels = Some(Map("foo1" -> "bar1", "env1" -> "exec-test-1")))
+ )
+
+ val pvcVolumeConf2 = KubernetesVolumeSpec(
+ "checkpointVolume2",
+ "/checkpoints2",
+ "",
+ true,
+ KubernetesPVCVolumeConf(claimName = "pvcClaim2",
+ storageClass = Some("gp3"),
+ size = Some("1Mi"),
+ labels = Some(Map("foo2" -> "bar2", "env2" -> "exec-test-2")))
+ )
+
+ val kubernetesConf = KubernetesTestConf.createExecutorConf(
+ volumes = Seq(pvcVolumeConf1, pvcVolumeConf2))
+ val step = new MountVolumesFeatureStep(kubernetesConf)
+ val configuredPod = step.configurePod(SparkPod.initialPod())
+
+ assert(configuredPod.pod.getSpec.getVolumes.size() === 2)
+ assert(configuredPod.container.getVolumeMounts.size() === 2)
+ }
+
test("Create and mount persistentVolumeClaims in executors") {
val volumeConf = KubernetesVolumeSpec(
"testVolume",
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/VolcanoFeatureStep.scala b/resource-managers/kubernetes/core/volcano/src/main/scala/org/apache/spark/deploy/k8s/features/VolcanoFeatureStep.scala
similarity index 100%
rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/VolcanoFeatureStep.scala
rename to resource-managers/kubernetes/core/volcano/src/main/scala/org/apache/spark/deploy/k8s/features/VolcanoFeatureStep.scala
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/VolcanoFeatureStepSuite.scala b/resource-managers/kubernetes/core/volcano/src/test/scala/org/apache/spark/deploy/k8s/features/VolcanoFeatureStepSuite.scala
similarity index 100%
rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/VolcanoFeatureStepSuite.scala
rename to resource-managers/kubernetes/core/volcano/src/test/scala/org/apache/spark/deploy/k8s/features/VolcanoFeatureStepSuite.scala
diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml
index 518c5bc217071..45ce25b8e037a 100644
--- a/resource-managers/kubernetes/integration-tests/pom.xml
+++ b/resource-managers/kubernetes/integration-tests/pom.xml
@@ -46,7 +46,6 @@
org.apache.spark.deploy.k8s.integrationtest.YuniKornTag
- **/*Volcano*.scalajarSpark Project Kubernetes Integration Tests
@@ -83,19 +82,6 @@
-
-
-
- net.alchim31.maven
- scala-maven-plugin
-
-
- ${volcano.exclude}
-
-
-
-
- org.codehaus.mojo
@@ -219,9 +205,6 @@
volcano
-
-
- io.fabric8
@@ -229,6 +212,28 @@
${kubernetes-client.version}
+
+
+
+ org.codehaus.mojo
+ build-helper-maven-plugin
+
+
+ add-volcano-test-sources
+ generate-test-sources
+
+ add-test-source
+
+
+
+
+
+
+
+
+
+
+
diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/VolcanoSuite.scala b/resource-managers/kubernetes/integration-tests/volcano/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/VolcanoSuite.scala
similarity index 100%
rename from resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/VolcanoSuite.scala
rename to resource-managers/kubernetes/integration-tests/volcano/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/VolcanoSuite.scala
diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/VolcanoTestsSuite.scala b/resource-managers/kubernetes/integration-tests/volcano/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/VolcanoTestsSuite.scala
similarity index 100%
rename from resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/VolcanoTestsSuite.scala
rename to resource-managers/kubernetes/integration-tests/volcano/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/VolcanoTestsSuite.scala
diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
index acfc0011f5d05..de28041acd41f 100644
--- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
+++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
@@ -146,6 +146,7 @@ BUCKETS: 'BUCKETS';
BY: 'BY';
BYTE: 'BYTE';
CACHE: 'CACHE';
+CALL: 'CALL';
CALLED: 'CALLED';
CASCADE: 'CASCADE';
CASE: 'CASE';
@@ -255,12 +256,14 @@ BINARY_HEX: 'X';
HOUR: 'HOUR';
HOURS: 'HOURS';
IDENTIFIER_KW: 'IDENTIFIER';
+IDENTITY: 'IDENTITY';
IF: 'IF';
IGNORE: 'IGNORE';
IMMEDIATE: 'IMMEDIATE';
IMPORT: 'IMPORT';
IN: 'IN';
INCLUDE: 'INCLUDE';
+INCREMENT: 'INCREMENT';
INDEX: 'INDEX';
INDEXES: 'INDEXES';
INNER: 'INNER';
@@ -276,6 +279,7 @@ INTO: 'INTO';
INVOKER: 'INVOKER';
IS: 'IS';
ITEMS: 'ITEMS';
+ITERATE: 'ITERATE';
JOIN: 'JOIN';
KEYS: 'KEYS';
LANGUAGE: 'LANGUAGE';
@@ -283,6 +287,7 @@ LAST: 'LAST';
LATERAL: 'LATERAL';
LAZY: 'LAZY';
LEADING: 'LEADING';
+LEAVE: 'LEAVE';
LEFT: 'LEFT';
LIKE: 'LIKE';
ILIKE: 'ILIKE';
@@ -362,6 +367,7 @@ REFERENCES: 'REFERENCES';
REFRESH: 'REFRESH';
RENAME: 'RENAME';
REPAIR: 'REPAIR';
+REPEAT: 'REPEAT';
REPEATABLE: 'REPEATABLE';
REPLACE: 'REPLACE';
RESET: 'RESET';
@@ -451,6 +457,7 @@ UNKNOWN: 'UNKNOWN';
UNLOCK: 'UNLOCK';
UNPIVOT: 'UNPIVOT';
UNSET: 'UNSET';
+UNTIL: 'UNTIL';
UPDATE: 'UPDATE';
USE: 'USE';
USER: 'USER';
@@ -501,6 +508,7 @@ TILDE: '~';
AMPERSAND: '&';
PIPE: '|';
CONCAT_PIPE: '||';
+OPERATOR_PIPE: '|>';
HAT: '^';
COLON: ':';
DOUBLE_COLON: '::';
diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index 5b8805821b045..094f7f5315b80 100644
--- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -64,7 +64,11 @@ compoundStatement
| setStatementWithOptionalVarKeyword
| beginEndCompoundBlock
| ifElseStatement
+ | caseStatement
| whileStatement
+ | repeatStatement
+ | leaveStatement
+ | iterateStatement
;
setStatementWithOptionalVarKeyword
@@ -83,6 +87,25 @@ ifElseStatement
(ELSE elseBody=compoundBody)? END IF
;
+repeatStatement
+ : beginLabel? REPEAT compoundBody UNTIL booleanExpression END REPEAT endLabel?
+ ;
+
+leaveStatement
+ : LEAVE multipartIdentifier
+ ;
+
+iterateStatement
+ : ITERATE multipartIdentifier
+ ;
+
+caseStatement
+ : CASE (WHEN conditions+=booleanExpression THEN conditionalBodies+=compoundBody)+
+ (ELSE elseBody=compoundBody)? END CASE #searchedCaseStatement
+ | CASE caseVariable=expression (WHEN conditionExpressions+=expression THEN conditionalBodies+=compoundBody)+
+ (ELSE elseBody=compoundBody)? END CASE #simpleCaseStatement
+ ;
+
singleStatement
: (statement|setResetStatement) SEMICOLON* EOF
;
@@ -275,6 +298,10 @@ statement
LEFT_PAREN columns=multipartIdentifierPropertyList RIGHT_PAREN
(OPTIONS options=propertyList)? #createIndex
| DROP INDEX (IF EXISTS)? identifier ON TABLE? identifierReference #dropIndex
+ | CALL identifierReference
+ LEFT_PAREN
+ (functionArgument (COMMA functionArgument)*)?
+ RIGHT_PAREN #call
| unsupportedHiveNativeCommands .*? #failNativeCommand
;
@@ -589,6 +616,7 @@ queryTerm
operator=INTERSECT setQuantifier? right=queryTerm #setOperation
| left=queryTerm {!legacy_setops_precedence_enabled}?
operator=(UNION | EXCEPT | SETMINUS) setQuantifier? right=queryTerm #setOperation
+ | left=queryTerm OPERATOR_PIPE operatorPipeRightSide #operatorPipeStatement
;
queryPrimary
@@ -1272,7 +1300,22 @@ colDefinitionOption
;
generationExpression
- : GENERATED ALWAYS AS LEFT_PAREN expression RIGHT_PAREN
+ : GENERATED ALWAYS AS LEFT_PAREN expression RIGHT_PAREN #generatedColumn
+ | GENERATED (ALWAYS | BY DEFAULT) AS IDENTITY identityColSpec? #identityColumn
+ ;
+
+identityColSpec
+ : LEFT_PAREN sequenceGeneratorOption* RIGHT_PAREN
+ ;
+
+sequenceGeneratorOption
+ : START WITH start=sequenceGeneratorStartOrStep
+ | INCREMENT BY step=sequenceGeneratorStartOrStep
+ ;
+
+sequenceGeneratorStartOrStep
+ : MINUS? INTEGER_VALUE
+ | MINUS? BIGINT_LITERAL
;
complexColTypeList
@@ -1447,6 +1490,11 @@ version
| stringLit
;
+operatorPipeRightSide
+ : selectClause
+ | whereClause
+ ;
+
// When `SQL_standard_keyword_behavior=true`, there are 2 kinds of keywords in Spark SQL.
// - Reserved keywords:
// Keywords that are reserved and can't be used as identifiers for table, view, column,
@@ -1562,11 +1610,13 @@ ansiNonReserved
| HOUR
| HOURS
| IDENTIFIER_KW
+ | IDENTITY
| IF
| IGNORE
| IMMEDIATE
| IMPORT
| INCLUDE
+ | INCREMENT
| INDEX
| INDEXES
| INPATH
@@ -1578,10 +1628,12 @@ ansiNonReserved
| INTERVAL
| INVOKER
| ITEMS
+ | ITERATE
| KEYS
| LANGUAGE
| LAST
| LAZY
+ | LEAVE
| LIKE
| ILIKE
| LIMIT
@@ -1648,6 +1700,7 @@ ansiNonReserved
| REFRESH
| RENAME
| REPAIR
+ | REPEAT
| REPEATABLE
| REPLACE
| RESET
@@ -1723,6 +1776,7 @@ ansiNonReserved
| UNLOCK
| UNPIVOT
| UNSET
+ | UNTIL
| UPDATE
| USE
| VALUES
@@ -1802,6 +1856,7 @@ nonReserved
| BY
| BYTE
| CACHE
+ | CALL
| CALLED
| CASCADE
| CASE
@@ -1908,12 +1963,14 @@ nonReserved
| HOUR
| HOURS
| IDENTIFIER_KW
+ | IDENTITY
| IF
| IGNORE
| IMMEDIATE
| IMPORT
| IN
| INCLUDE
+ | INCREMENT
| INDEX
| INDEXES
| INPATH
@@ -1927,11 +1984,13 @@ nonReserved
| INVOKER
| IS
| ITEMS
+ | ITERATE
| KEYS
| LANGUAGE
| LAST
| LAZY
| LEADING
+ | LEAVE
| LIKE
| LONG
| ILIKE
@@ -2009,6 +2068,7 @@ nonReserved
| REFRESH
| RENAME
| REPAIR
+ | REPEAT
| REPEATABLE
| REPLACE
| RESET
@@ -2093,6 +2153,7 @@ nonReserved
| UNLOCK
| UNPIVOT
| UNSET
+ | UNTIL
| UPDATE
| USE
| USER
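For readability, the shapes of the new grammar rules above, collected as plain strings; they are only printed, since whether a given Spark build parses and executes each of them can depend on version and SQL confs, and all identifiers and values are made up.

```python
# Syntax shapes for the grammar additions above; identifiers and values are illustrative.
examples = {
    # statement: CALL identifierReference ( functionArgument, ... )
    "call": "CALL cat.db.my_procedure(1, 'x')",
    # generationExpression: GENERATED (ALWAYS | BY DEFAULT) AS IDENTITY [(options)]
    "identity_column": (
        "id BIGINT GENERATED BY DEFAULT AS IDENTITY (START WITH 10 INCREMENT BY 5)"
    ),
    # repeatStatement: the body runs at least once, until the condition becomes true
    "repeat_loop": "lbl: REPEAT SET v = v + 1; UNTIL v >= 3 END REPEAT lbl",
    # leaveStatement / iterateStatement: exit or restart a labeled loop
    "leave": "LEAVE lbl",
    "iterate": "ITERATE lbl",
    # operatorPipeStatement: the right-hand side is currently a SELECT or WHERE clause
    "pipe": "SELECT x FROM t |> WHERE x > 1 |> SELECT x * 10 AS y",
}

for rule, sql_text in examples.items():
    print(f"{rule}: {sql_text}")
```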
diff --git a/sql/api/src/main/java/org/apache/spark/sql/connector/catalog/IdentityColumnSpec.java b/sql/api/src/main/java/org/apache/spark/sql/connector/catalog/IdentityColumnSpec.java
new file mode 100644
index 0000000000000..4a8943736bd31
--- /dev/null
+++ b/sql/api/src/main/java/org/apache/spark/sql/connector/catalog/IdentityColumnSpec.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog;
+import org.apache.spark.annotation.Evolving;
+
+import java.util.Objects;
+
+/**
+ * Identity column specification.
+ */
+@Evolving
+public class IdentityColumnSpec {
+ private final long start;
+ private final long step;
+ private final boolean allowExplicitInsert;
+
+ /**
+ * Creates an identity column specification.
+ * @param start the start value to generate the identity values
+ * @param step the step value to generate the identity values
+ * @param allowExplicitInsert whether the identity column allows explicit insertion of values
+ */
+ public IdentityColumnSpec(long start, long step, boolean allowExplicitInsert) {
+ this.start = start;
+ this.step = step;
+ this.allowExplicitInsert = allowExplicitInsert;
+ }
+
+ /**
+ * @return the start value to generate the identity values
+ */
+ public long getStart() {
+ return start;
+ }
+
+ /**
+ * @return the step value to generate the identity values
+ */
+ public long getStep() {
+ return step;
+ }
+
+ /**
+ * @return whether the identity column allows explicit insertion of values
+ */
+ public boolean isAllowExplicitInsert() {
+ return allowExplicitInsert;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ IdentityColumnSpec that = (IdentityColumnSpec) o;
+ return start == that.start &&
+ step == that.step &&
+ allowExplicitInsert == that.allowExplicitInsert;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(start, step, allowExplicitInsert);
+ }
+
+ @Override
+ public String toString() {
+ return "IdentityColumnSpec{" +
+ "start=" + start +
+ ", step=" + step +
+ ", allowExplicitInsert=" + allowExplicitInsert +
+ "}";
+ }
+}
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/api/src/main/scala/org/apache/spark/sql/AnalysisException.scala
index 7a428f6cc3288..a2c1f2cc41f8f 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/AnalysisException.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/AnalysisException.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin}
* @since 1.3.0
*/
@Stable
-class AnalysisException protected(
+class AnalysisException protected (
val message: String,
val line: Option[Int] = None,
val startPosition: Option[Int] = None,
@@ -37,12 +37,12 @@ class AnalysisException protected(
val errorClass: Option[String] = None,
val messageParameters: Map[String, String] = Map.empty,
val context: Array[QueryContext] = Array.empty)
- extends Exception(message, cause.orNull) with SparkThrowable with Serializable with WithOrigin {
+ extends Exception(message, cause.orNull)
+ with SparkThrowable
+ with Serializable
+ with WithOrigin {
- def this(
- errorClass: String,
- messageParameters: Map[String, String],
- cause: Option[Throwable]) =
+ def this(errorClass: String, messageParameters: Map[String, String], cause: Option[Throwable]) =
this(
SparkThrowableHelper.getMessage(errorClass, messageParameters),
errorClass = Some(errorClass),
@@ -73,18 +73,10 @@ class AnalysisException protected(
cause = null,
context = context)
- def this(
- errorClass: String,
- messageParameters: Map[String, String]) =
- this(
- errorClass = errorClass,
- messageParameters = messageParameters,
- cause = None)
+ def this(errorClass: String, messageParameters: Map[String, String]) =
+ this(errorClass = errorClass, messageParameters = messageParameters, cause = None)
- def this(
- errorClass: String,
- messageParameters: Map[String, String],
- origin: Origin) =
+ def this(errorClass: String, messageParameters: Map[String, String], origin: Origin) =
this(
SparkThrowableHelper.getMessage(errorClass, messageParameters),
line = origin.line,
@@ -115,8 +107,14 @@ class AnalysisException protected(
errorClass: Option[String] = this.errorClass,
messageParameters: Map[String, String] = this.messageParameters,
context: Array[QueryContext] = this.context): AnalysisException =
- new AnalysisException(message, line, startPosition, cause, errorClass,
- messageParameters, context)
+ new AnalysisException(
+ message,
+ line,
+ startPosition,
+ cause,
+ errorClass,
+ messageParameters,
+ context)
def withPosition(origin: Origin): AnalysisException = {
val newException = this.copy(
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Artifact.scala b/sql/api/src/main/scala/org/apache/spark/sql/Artifact.scala
index c78280af6e021..7e020df06fe47 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/Artifact.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Artifact.scala
@@ -28,8 +28,7 @@ import org.apache.spark.sql.util.ArtifactUtils
import org.apache.spark.util.ArrayImplicits._
import org.apache.spark.util.MavenUtils
-
-private[sql] class Artifact private(val path: Path, val storage: LocalData) {
+private[sql] class Artifact private (val path: Path, val storage: LocalData) {
require(!path.isAbsolute, s"Bad path: $path")
lazy val size: Long = storage match {
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Column.scala b/sql/api/src/main/scala/org/apache/spark/sql/Column.scala
index cd6a04b2a0562..31ce44eca1684 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Column.scala
@@ -72,34 +72,34 @@ private[spark] object Column {
isDistinct: Boolean,
isInternal: Boolean,
inputs: Seq[Column]): Column = withOrigin {
- Column(internal.UnresolvedFunction(
- name,
- inputs.map(_.node),
- isDistinct = isDistinct,
- isInternal = isInternal))
+ Column(
+ internal.UnresolvedFunction(
+ name,
+ inputs.map(_.node),
+ isDistinct = isDistinct,
+ isInternal = isInternal))
}
}
/**
- * A [[Column]] where an [[Encoder]] has been given for the expected input and return type.
- * To create a [[TypedColumn]], use the `as` function on a [[Column]].
+ * A [[Column]] where an [[Encoder]] has been given for the expected input and return type. To
+ * create a [[TypedColumn]], use the `as` function on a [[Column]].
*
- * @tparam T The input type expected for this expression. Can be `Any` if the expression is type
- * checked by the analyzer instead of the compiler (i.e. `expr("sum(...)")`).
- * @tparam U The output type of this column.
+ * @tparam T
+ * The input type expected for this expression. Can be `Any` if the expression is type checked
+ * by the analyzer instead of the compiler (i.e. `expr("sum(...)")`).
+ * @tparam U
+ * The output type of this column.
*
* @since 1.6.0
*/
@Stable
-class TypedColumn[-T, U](
- node: ColumnNode,
- private[sql] val encoder: Encoder[U])
- extends Column(node) {
+class TypedColumn[-T, U](node: ColumnNode, private[sql] val encoder: Encoder[U])
+ extends Column(node) {
/**
- * Gives the [[TypedColumn]] a name (alias).
- * If the current `TypedColumn` has metadata associated with it, this metadata will be propagated
- * to the new column.
+ * Gives the [[TypedColumn]] a name (alias). If the current `TypedColumn` has metadata
+ * associated with it, this metadata will be propagated to the new column.
*
* @group expr_ops
* @since 2.0.0
@@ -168,23 +168,20 @@ class Column(val node: ColumnNode) extends Logging {
override def hashCode: Int = this.node.normalized.hashCode()
/**
- * Provides a type hint about the expected return value of this column. This information can
- * be used by operations such as `select` on a [[Dataset]] to automatically convert the
- * results into the correct JVM types.
+ * Provides a type hint about the expected return value of this column. This information can be
+ * used by operations such as `select` on a [[Dataset]] to automatically convert the results
+ * into the correct JVM types.
* @since 1.6.0
*/
- def as[U : Encoder]: TypedColumn[Any, U] = new TypedColumn[Any, U](node, implicitly[Encoder[U]])
+ def as[U: Encoder]: TypedColumn[Any, U] = new TypedColumn[Any, U](node, implicitly[Encoder[U]])
/**
- * Extracts a value or values from a complex type.
- * The following types of extraction are supported:
- * <ul>
- * <li>Given an Array, an integer ordinal can be used to retrieve a single value.</li>
- * <li>Given a Map, a key of the correct type can be used to retrieve an individual value.</li>
- * <li>Given a Struct, a string fieldName can be used to extract that field.</li>
- * <li>Given an Array of Structs, a string fieldName can be used to extract filed
- * of every struct in that array, and return an Array of fields.</li>
- * </ul>
+ * Extracts a value or values from a complex type. The following types of extraction are
+ * supported: <ul> <li>Given an Array, an integer ordinal can be used to retrieve a single
+ * value.</li> <li>Given a Map, a key of the correct type can be used to retrieve an individual
+ * value.</li> <li>Given a Struct, a string fieldName can be used to extract that field.</li>
+ * <li>Given an Array of Structs, a string fieldName can be used to extract filed of every
+ * struct in that array, and return an Array of fields.</li> </ul>
* @group expr_ops
* @since 1.4.0
*/
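As a quick reference for the four extraction forms listed above, a minimal usage sketch (not part of the patch; the DataFrame `df` and its column names are assumed):
{{{
  // Illustrative only: the four extraction forms from the Scaladoc above.
  import org.apache.spark.sql.functions.col

  df.select(col("arr")(0))          // Array + integer ordinal -> single value
  df.select(col("m")("key"))        // Map + key -> individual value
  df.select(col("s")("a"))          // Struct + field name -> that field
  df.select(col("structs")("a"))    // Array of Structs + field name -> Array of fields
}}}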
@@ -283,8 +280,8 @@ class Column(val node: ColumnNode) extends Logging {
*
* @group expr_ops
* @since 2.0.0
- */
- def =!= (other: Any): Column = !(this === other)
+ */
+ def =!=(other: Any): Column = !(this === other)
/**
* Inequality test.
@@ -300,9 +297,9 @@ class Column(val node: ColumnNode) extends Logging {
*
* @group expr_ops
* @since 1.3.0
- */
+ */
@deprecated("!== does not have the same precedence as ===, use =!= instead", "2.0.0")
- def !== (other: Any): Column = this =!= other
+ def !==(other: Any): Column = this =!= other
/**
* Inequality test.
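A short sketch of the reformatted inequality operators (illustrative only; `df` and its `state` column are assumed):
{{{
  df.filter(df("state") =!= "CA")      // preferred inequality test
  df.filter(!(df("state") === "CA"))   // the expansion =!= desugars to
  // df("state") !== "CA"              // deprecated since 2.0.0: precedence differs from ===
}}}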
@@ -464,8 +461,8 @@ class Column(val node: ColumnNode) extends Logging {
def eqNullSafe(other: Any): Column = this <=> other
/**
- * Evaluates a list of conditions and returns one of multiple possible result expressions.
- * If otherwise is not defined at the end, null is returned for unmatched conditions.
+ * Evaluates a list of conditions and returns one of multiple possible result expressions. If
+ * otherwise is not defined at the end, null is returned for unmatched conditions.
*
* {{{
* // Example: encoding gender string column into integer.
@@ -489,8 +486,7 @@ class Column(val node: ColumnNode) extends Logging {
case internal.CaseWhenOtherwise(branches, None, _) =>
internal.CaseWhenOtherwise(branches :+ ((condition.node, lit(value).node)), None)
case internal.CaseWhenOtherwise(_, Some(_), _) =>
- throw new IllegalArgumentException(
- "when() cannot be applied once otherwise() is applied")
+ throw new IllegalArgumentException("when() cannot be applied once otherwise() is applied")
case _ =>
throw new IllegalArgumentException(
"when() can only be applied on a Column previously generated by when() function")
@@ -498,8 +494,8 @@ class Column(val node: ColumnNode) extends Logging {
}
/**
- * Evaluates a list of conditions and returns one of multiple possible result expressions.
- * If otherwise is not defined at the end, null is returned for unmatched conditions.
+ * Evaluates a list of conditions and returns one of multiple possible result expressions. If
+ * otherwise is not defined at the end, null is returned for unmatched conditions.
*
* {{{
* // Example: encoding gender string column into integer.
@@ -765,13 +761,11 @@ class Column(val node: ColumnNode) extends Logging {
* A boolean expression that is evaluated to true if the value of this expression is contained
* by the evaluated values of the arguments.
*
- * Note: Since the type of the elements in the list are inferred only during the run time,
- * the elements will be "up-casted" to the most common type for comparison.
- * For eg:
- * 1) In the case of "Int vs String", the "Int" will be up-casted to "String" and the
- * comparison will look like "String vs String".
- * 2) In the case of "Float vs Double", the "Float" will be up-casted to "Double" and the
- * comparison will look like "Double vs Double"
+ * Note: Since the type of the elements in the list are inferred only during the run time, the
+ * elements will be "up-casted" to the most common type for comparison. For eg: 1) In the case
+ * of "Int vs String", the "Int" will be up-casted to "String" and the comparison will look like
+ * "String vs String". 2) In the case of "Float vs Double", the "Float" will be up-casted to
+ * "Double" and the comparison will look like "Double vs Double"
*
* @group expr_ops
* @since 1.5.0
@@ -784,12 +778,10 @@ class Column(val node: ColumnNode) extends Logging {
* by the provided collection.
*
* Note: Since the type of the elements in the collection are inferred only during the run time,
- * the elements will be "up-casted" to the most common type for comparison.
- * For eg:
- * 1) In the case of "Int vs String", the "Int" will be up-casted to "String" and the
- * comparison will look like "String vs String".
- * 2) In the case of "Float vs Double", the "Float" will be up-casted to "Double" and the
- * comparison will look like "Double vs Double"
+ * the elements will be "up-casted" to the most common type for comparison. For eg: 1) In the
+ * case of "Int vs String", the "Int" will be up-casted to "String" and the comparison will look
+ * like "String vs String". 2) In the case of "Float vs Double", the "Float" will be up-casted
+ * to "Double" and the comparison will look like "Double vs Double"
*
* @group expr_ops
* @since 2.4.0
@@ -801,12 +793,10 @@ class Column(val node: ColumnNode) extends Logging {
* by the provided collection.
*
* Note: Since the type of the elements in the collection are inferred only during the run time,
- * the elements will be "up-casted" to the most common type for comparison.
- * For eg:
- * 1) In the case of "Int vs String", the "Int" will be up-casted to "String" and the
- * comparison will look like "String vs String".
- * 2) In the case of "Float vs Double", the "Float" will be up-casted to "Double" and the
- * comparison will look like "Double vs Double"
+ * the elements will be "up-casted" to the most common type for comparison. For eg: 1) In the
+ * case of "Int vs String", the "Int" will be up-casted to "String" and the comparison will look
+ * like "String vs String". 2) In the case of "Float vs Double", the "Float" will be up-casted
+ * to "Double" and the comparison will look like "Double vs Double"
*
* @group java_expr_ops
* @since 2.4.0
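A hedged sketch of the up-casting behaviour described in the three hunks above (column names and types are assumed):
{{{
  import org.apache.spark.sql.functions.col

  df.filter(col("intCol").isin("1", "2"))            // Int vs String -> compared as String
  df.filter(col("floatCol").isin(1.0d, 2.5d))        // Float vs Double -> compared as Double
  df.filter(col("id").isInCollection(Seq(1, 2, 3)))  // collection variant (since 2.4.0)
}}}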
@@ -822,8 +812,7 @@ class Column(val node: ColumnNode) extends Logging {
def like(literal: String): Column = fn("like", literal)
/**
- * SQL RLIKE expression (LIKE with Regex). Returns a boolean column based on a regex
- * match.
+ * SQL RLIKE expression (LIKE with Regex). Returns a boolean column based on a regex match.
*
* @group expr_ops
* @since 1.3.0
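For reference, a minimal sketch contrasting `like`, `rlike`, and the case-insensitive `ilike` that appears in the next hunk (names assumed):
{{{
  import org.apache.spark.sql.functions.col

  df.filter(col("name").like("Al%"))     // SQL LIKE pattern
  df.filter(col("name").rlike("^Al.*"))  // Java regular expression
  df.filter(col("name").ilike("al%"))    // case-insensitive LIKE
}}}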
@@ -839,8 +828,8 @@ class Column(val node: ColumnNode) extends Logging {
def ilike(literal: String): Column = fn("ilike", literal)
/**
- * An expression that gets an item at position `ordinal` out of an array,
- * or gets a value by key `key` in a `MapType`.
+ * An expression that gets an item at position `ordinal` out of an array, or gets a value by key
+ * `key` in a `MapType`.
*
* @group expr_ops
* @since 1.3.0
@@ -885,8 +874,8 @@ class Column(val node: ColumnNode) extends Logging {
* // result: {"a":{"a":1,"b":2,"c":3,"d":4}}
* }}}
*
- * However, if you are going to add/replace multiple nested fields, it is more optimal to extract
- * out the nested struct before adding/replacing multiple fields e.g.
+ * However, if you are going to add/replace multiple nested fields, it is more optimal to
+ * extract out the nested struct before adding/replacing multiple fields e.g.
*
* {{{
* val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
@@ -906,8 +895,8 @@ class Column(val node: ColumnNode) extends Logging {
// scalastyle:off line.size.limit
/**
- * An expression that drops fields in `StructType` by name.
- * This is a no-op if schema doesn't contain field name(s).
+ * An expression that drops fields in `StructType` by name. This is a no-op if schema doesn't
+ * contain field name(s).
*
* {{{
* val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
@@ -951,8 +940,8 @@ class Column(val node: ColumnNode) extends Logging {
* // result: {"a":{"a":1}}
* }}}
*
- * However, if you are going to drop multiple nested fields, it is more optimal to extract
- * out the nested struct before dropping multiple fields from it e.g.
+ * However, if you are going to drop multiple nested fields, it is more optimal to extract out
+ * the nested struct before dropping multiple fields from it e.g.
*
* {{{
* val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
@@ -980,8 +969,10 @@ class Column(val node: ColumnNode) extends Logging {
/**
* An expression that returns a substring.
- * @param startPos expression for the starting position.
- * @param len expression for the length of the substring.
+ * @param startPos
+ * expression for the starting position.
+ * @param len
+ * expression for the length of the substring.
*
* @group expr_ops
* @since 1.3.0
@@ -990,8 +981,10 @@ class Column(val node: ColumnNode) extends Logging {
/**
* An expression that returns a substring.
- * @param startPos starting position.
- * @param len length of the substring.
+ * @param startPos
+ * starting position.
+ * @param len
+ * length of the substring.
*
* @group expr_ops
* @since 1.3.0
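Both `substr` overloads documented above in one illustrative sketch (column names assumed):
{{{
  import org.apache.spark.sql.functions.{col, lit}

  df.select(col("s").substr(1, 3))                 // literal start position and length
  df.select(col("s").substr(lit(1), col("len")))   // expressions for both parameters
}}}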
@@ -1057,9 +1050,9 @@ class Column(val node: ColumnNode) extends Logging {
* df.select($"colA".as("colB"))
* }}}
*
- * If the current column has metadata associated with it, this metadata will be propagated
- * to the new column. If this not desired, use the API `as(alias: String, metadata: Metadata)`
- * with explicit metadata.
+ * If the current column has metadata associated with it, this metadata will be propagated to
+ * the new column. If this not desired, use the API `as(alias: String, metadata: Metadata)` with
+ * explicit metadata.
*
* @group expr_ops
* @since 1.3.0
@@ -1097,9 +1090,9 @@ class Column(val node: ColumnNode) extends Logging {
* df.select($"colA".as("colB"))
* }}}
*
- * If the current column has metadata associated with it, this metadata will be propagated
- * to the new column. If this not desired, use the API `as(alias: String, metadata: Metadata)`
- * with explicit metadata.
+ * If the current column has metadata associated with it, this metadata will be propagated to
+ * the new column. If this not desired, use the API `as(alias: String, metadata: Metadata)` with
+ * explicit metadata.
*
* @group expr_ops
* @since 1.3.0
@@ -1126,9 +1119,9 @@ class Column(val node: ColumnNode) extends Logging {
* df.select($"colA".name("colB"))
* }}}
*
- * If the current column has metadata associated with it, this metadata will be propagated
- * to the new column. If this not desired, use the API `as(alias: String, metadata: Metadata)`
- * with explicit metadata.
+ * If the current column has metadata associated with it, this metadata will be propagated to
+ * the new column. If this not desired, use the API `as(alias: String, metadata: Metadata)` with
+ * explicit metadata.
*
* @group expr_ops
* @since 2.0.0
@@ -1152,9 +1145,9 @@ class Column(val node: ColumnNode) extends Logging {
def cast(to: DataType): Column = Column(internal.Cast(node, to))
/**
- * Casts the column to a different data type, using the canonical string representation
- * of the type. The supported types are: `string`, `boolean`, `byte`, `short`, `int`, `long`,
- * `float`, `double`, `decimal`, `date`, `timestamp`.
+ * Casts the column to a different data type, using the canonical string representation of the
+ * type. The supported types are: `string`, `boolean`, `byte`, `short`, `int`, `long`, `float`,
+ * `double`, `decimal`, `date`, `timestamp`.
* {{{
* // Casts colA to integer.
* df.select(df("colA").cast("int"))
@@ -1224,8 +1217,8 @@ class Column(val node: ColumnNode) extends Logging {
def desc: Column = desc_nulls_last
/**
- * Returns a sort expression based on the descending order of the column,
- * and null values appear before non-null values.
+ * Returns a sort expression based on the descending order of the column, and null values appear
+ * before non-null values.
* {{{
* // Scala: sort a DataFrame by age column in descending order and null values appearing first.
* df.sort(df("age").desc_nulls_first)
@@ -1237,13 +1230,12 @@ class Column(val node: ColumnNode) extends Logging {
* @group expr_ops
* @since 2.1.0
*/
- def desc_nulls_first: Column = sortOrder(
- internal.SortOrder.Descending,
- internal.SortOrder.NullsFirst)
+ def desc_nulls_first: Column =
+ sortOrder(internal.SortOrder.Descending, internal.SortOrder.NullsFirst)
/**
- * Returns a sort expression based on the descending order of the column,
- * and null values appear after non-null values.
+ * Returns a sort expression based on the descending order of the column, and null values appear
+ * after non-null values.
* {{{
* // Scala: sort a DataFrame by age column in descending order and null values appearing last.
* df.sort(df("age").desc_nulls_last)
@@ -1255,9 +1247,8 @@ class Column(val node: ColumnNode) extends Logging {
* @group expr_ops
* @since 2.1.0
*/
- def desc_nulls_last: Column = sortOrder(
- internal.SortOrder.Descending,
- internal.SortOrder.NullsLast)
+ def desc_nulls_last: Column =
+ sortOrder(internal.SortOrder.Descending, internal.SortOrder.NullsLast)
/**
* Returns a sort expression based on ascending order of the column.
@@ -1275,8 +1266,8 @@ class Column(val node: ColumnNode) extends Logging {
def asc: Column = asc_nulls_first
/**
- * Returns a sort expression based on ascending order of the column,
- * and null values return before non-null values.
+ * Returns a sort expression based on ascending order of the column, and null values return
+ * before non-null values.
* {{{
* // Scala: sort a DataFrame by age column in ascending order and null values appearing first.
* df.sort(df("age").asc_nulls_first)
@@ -1288,13 +1279,12 @@ class Column(val node: ColumnNode) extends Logging {
* @group expr_ops
* @since 2.1.0
*/
- def asc_nulls_first: Column = sortOrder(
- internal.SortOrder.Ascending,
- internal.SortOrder.NullsFirst)
+ def asc_nulls_first: Column =
+ sortOrder(internal.SortOrder.Ascending, internal.SortOrder.NullsFirst)
/**
- * Returns a sort expression based on ascending order of the column,
- * and null values appear after non-null values.
+ * Returns a sort expression based on ascending order of the column, and null values appear
+ * after non-null values.
* {{{
* // Scala: sort a DataFrame by age column in ascending order and null values appearing last.
* df.sort(df("age").asc_nulls_last)
@@ -1306,9 +1296,8 @@ class Column(val node: ColumnNode) extends Logging {
* @group expr_ops
* @since 2.1.0
*/
- def asc_nulls_last: Column = sortOrder(
- internal.SortOrder.Ascending,
- internal.SortOrder.NullsLast)
+ def asc_nulls_last: Column =
+ sortOrder(internal.SortOrder.Ascending, internal.SortOrder.NullsLast)
/**
* Prints the expression to the console for debugging purposes.
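A minimal usage sketch for the null-ordering sort variants reflowed above (a nullable `age` column is assumed):
{{{
  df.sort(df("age").asc_nulls_first)   // ascending, nulls before non-null values
  df.sort(df("age").desc_nulls_last)   // descending, nulls after non-null values
}}}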
@@ -1378,8 +1367,8 @@ class Column(val node: ColumnNode) extends Logging {
}
/**
- * Defines an empty analytic clause. In this case the analytic function is applied
- * and presented for all rows in the result set.
+ * Defines an empty analytic clause. In this case the analytic function is applied and presented
+ * for all rows in the result set.
*
* {{{
* df.select(
@@ -1395,7 +1384,6 @@ class Column(val node: ColumnNode) extends Logging {
}
-
/**
* A convenient class used for constructing schema.
*
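To round off the Column changes, a sketch of the analytic clause described above, contrasting an explicit window spec with the empty `over()` (not part of the patch; `dept` and `salary` columns assumed):
{{{
  import org.apache.spark.sql.expressions.Window
  import org.apache.spark.sql.functions.{rank, sum}

  df.select(
    rank().over(Window.partitionBy("dept").orderBy("salary")),  // per-department window
    sum("salary").over())                                       // empty clause: whole result set
}}}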
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/api/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 96855ee5ad164..1838c6bc8468f 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -34,15 +34,11 @@ import org.apache.spark.sql.errors.CompilationErrors
abstract class DataFrameWriter[T] {
/**
- * Specifies the behavior when data or table already exists. Options include:
- * <ul>
- * <li>`SaveMode.Overwrite`: overwrite the existing data.</li>
- * <li>`SaveMode.Append`: append the data.</li>
- * <li>`SaveMode.Ignore`: ignore the operation (i.e. no-op).</li>
- * <li>`SaveMode.ErrorIfExists`: throw an exception at runtime.</li>
- * </ul>
- *
- * The default option is `ErrorIfExists`.
+ * Specifies the behavior when data or table already exists. Options include:
+ * <ul> <li>`SaveMode.Overwrite`: overwrite the existing data.</li> <li>`SaveMode.Append`: append the
+ * data.</li> <li>`SaveMode.Ignore`: ignore the operation (i.e. no-op).</li>
+ * <li>`SaveMode.ErrorIfExists`: throw an exception at runtime.</li> </ul> The default
+ * option is `ErrorIfExists`.
*
* @since 1.4.0
*/
@@ -52,13 +48,10 @@ abstract class DataFrameWriter[T] {
}
/**
- * Specifies the behavior when data or table already exists. Options include:
- * <ul>
- * <li>`overwrite`: overwrite the existing data.</li>
- * <li>`append`: append the data.</li>
- * <li>`ignore`: ignore the operation (i.e. no-op).</li>
- * <li>`error` or `errorifexists`: default option, throw an exception at runtime.</li>
- * </ul>
+ * Specifies the behavior when data or table already exists. Options include:
+ * <ul> <li>`overwrite`: overwrite the existing data.</li> <li>`append`: append the data.</li>
+ * <li>`ignore`: ignore the operation (i.e. no-op).</li> <li>`error` or `errorifexists`: default
+ * option, throw an exception at runtime.</li> </ul>
*
* @since 1.4.0
*/
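Both `mode` overloads documented above, as an illustrative sketch (output path assumed):
{{{
  import org.apache.spark.sql.SaveMode

  df.write.mode(SaveMode.Overwrite).parquet("/tmp/out")  // enum form
  df.write.mode("append").parquet("/tmp/out")            // string form; "error" is the default
}}}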
@@ -85,8 +78,8 @@ abstract class DataFrameWriter[T] {
/**
* Adds an output option for the underlying data source.
*
- * All options are maintained in a case-insensitive way in terms of key names.
- * If a new option has the same key case-insensitively, it will override the existing option.
+ * All options are maintained in a case-insensitive way in terms of key names. If a new option
+ * has the same key case-insensitively, it will override the existing option.
*
* @since 1.4.0
*/
@@ -98,8 +91,8 @@ abstract class DataFrameWriter[T] {
/**
* Adds an output option for the underlying data source.
*
- * All options are maintained in a case-insensitive way in terms of key names.
- * If a new option has the same key case-insensitively, it will override the existing option.
+ * All options are maintained in a case-insensitive way in terms of key names. If a new option
+ * has the same key case-insensitively, it will override the existing option.
*
* @since 2.0.0
*/
@@ -108,8 +101,8 @@ abstract class DataFrameWriter[T] {
/**
* Adds an output option for the underlying data source.
*
- * All options are maintained in a case-insensitive way in terms of key names.
- * If a new option has the same key case-insensitively, it will override the existing option.
+ * All options are maintained in a case-insensitive way in terms of key names. If a new option
+ * has the same key case-insensitively, it will override the existing option.
*
* @since 2.0.0
*/
@@ -118,8 +111,8 @@ abstract class DataFrameWriter[T] {
/**
* Adds an output option for the underlying data source.
*
- * All options are maintained in a case-insensitive way in terms of key names.
- * If a new option has the same key case-insensitively, it will override the existing option.
+ * All options are maintained in a case-insensitive way in terms of key names. If a new option
+ * has the same key case-insensitively, it will override the existing option.
*
* @since 2.0.0
*/
@@ -128,8 +121,8 @@ abstract class DataFrameWriter[T] {
/**
* (Scala-specific) Adds output options for the underlying data source.
*
- * All options are maintained in a case-insensitive way in terms of key names.
- * If a new option has the same key case-insensitively, it will override the existing option.
+ * All options are maintained in a case-insensitive way in terms of key names. If a new option
+ * has the same key case-insensitively, it will override the existing option.
*
* @since 1.4.0
*/
@@ -141,8 +134,8 @@ abstract class DataFrameWriter[T] {
/**
* Adds output options for the underlying data source.
*
- * All options are maintained in a case-insensitive way in terms of key names.
- * If a new option has the same key case-insensitively, it will override the existing option.
+ * All options are maintained in a case-insensitive way in terms of key names. If a new option
+ * has the same key case-insensitively, it will override the existing option.
*
* @since 1.4.0
*/
@@ -154,16 +147,13 @@ abstract class DataFrameWriter[T] {
/**
* Partitions the output by the given columns on the file system. If specified, the output is
* laid out on the file system similar to Hive's partitioning scheme. As an example, when we
- * partition a dataset by year and then month, the directory layout would look like:
- * <ul>
- * <li>year=2016/month=01/</li>
- * <li>year=2016/month=02/</li>
- * </ul>
+ * partition a dataset by year and then month, the directory layout would look like:
+ * <ul> <li>year=2016/month=01/</li> <li>year=2016/month=02/</li> </ul>
*
- * Partitioning is one of the most widely used techniques to optimize physical data layout.
- * It provides a coarse-grained index for skipping unnecessary data reads when queries have
- * predicates on the partitioned columns. In order for partitioning to work well, the number
- * of distinct values in each column should typically be less than tens of thousands.
+ * Partitioning is one of the most widely used techniques to optimize physical data layout. It
+ * provides a coarse-grained index for skipping unnecessary data reads when queries have
+ * predicates on the partitioned columns. In order for partitioning to work well, the number of
+ * distinct values in each column should typically be less than tens of thousands.
*
* This is applicable for all file-based data sources (e.g. Parquet, JSON) starting with Spark
* 2.1.0.
@@ -179,8 +169,8 @@ abstract class DataFrameWriter[T] {
/**
* Buckets the output by the given columns. If specified, the output is laid out on the file
- * system similar to Hive's bucketing scheme, but with a different bucket hash function
- * and is not compatible with Hive's bucketing.
+ * system similar to Hive's bucketing scheme, but with a different bucket hash function and is
+ * not compatible with Hive's bucketing.
*
* This is applicable for all file-based data sources (e.g. Parquet, JSON) starting with Spark
* 2.1.0.
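A sketch combining the partitioning and bucketing docs above (table and column names assumed; `bucketBy` requires `saveAsTable`):
{{{
  df.write
    .partitionBy("year", "month")    // directory layout: year=2016/month=01/ ...
    .bucketBy(8, "user_id")          // Spark bucketing, not compatible with Hive's
    .sortBy("user_id")
    .saveAsTable("events_bucketed")
}}}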
@@ -241,13 +231,15 @@ abstract class DataFrameWriter[T] {
def save(): Unit
/**
- * Inserts the content of the `DataFrame` to the specified table. It requires that
- * the schema of the `DataFrame` is the same as the schema of the table.
+ * Inserts the content of the `DataFrame` to the specified table. It requires that the schema of
+ * the `DataFrame` is the same as the schema of the table.
*
- * @note Unlike `saveAsTable`, `insertInto` ignores the column names and just uses position-based
- * resolution. For example:
- * @note SaveMode.ErrorIfExists and SaveMode.Ignore behave as SaveMode.Append in `insertInto` as
- * `insertInto` is not a table creating operation.
+ * @note
+ * Unlike `saveAsTable`, `insertInto` ignores the column names and just uses position-based
+ * resolution. For example:
+ * @note
+ * SaveMode.ErrorIfExists and SaveMode.Ignore behave as SaveMode.Append in `insertInto` as
+ * `insertInto` is not a table creating operation.
*
* {{{
* scala> Seq((1, 2)).toDF("i", "j").write.mode("overwrite").saveAsTable("t1")
@@ -263,7 +255,7 @@ abstract class DataFrameWriter[T] {
* +---+---+
* }}}
*
- * Because it inserts data to an existing table, format or options will be ignored.
+ * Because it inserts data to an existing table, format or options will be ignored.
* @since 1.4.0
*/
def insertInto(tableName: String): Unit
@@ -271,15 +263,15 @@ abstract class DataFrameWriter[T] {
/**
* Saves the content of the `DataFrame` as the specified table.
*
- * In the case the table already exists, behavior of this function depends on the
- * save mode, specified by the `mode` function (default to throwing an exception).
- * When `mode` is `Overwrite`, the schema of the `DataFrame` does not need to be
- * the same as that of the existing table.
+ * In the case the table already exists, behavior of this function depends on the save mode,
+ * specified by the `mode` function (default to throwing an exception). When `mode` is
+ * `Overwrite`, the schema of the `DataFrame` does not need to be the same as that of the
+ * existing table.
*
* When `mode` is `Append`, if there is an existing table, we will use the format and options of
* the existing table. The column order in the schema of the `DataFrame` doesn't need to be same
- * as that of the existing table. Unlike `insertInto`, `saveAsTable` will use the column names to
- * find the correct column positions. For example:
+ * as that of the existing table. Unlike `insertInto`, `saveAsTable` will use the column names
+ * to find the correct column positions. For example:
*
* {{{
* scala> Seq((1, 2)).toDF("i", "j").write.mode("overwrite").saveAsTable("t1")
@@ -293,10 +285,10 @@ abstract class DataFrameWriter[T] {
* +---+---+
* }}}
*
- * In this method, save mode is used to determine the behavior if the data source table exists in
- * Spark catalog. We will always overwrite the underlying data of data source (e.g. a table in
- * JDBC data source) if the table doesn't exist in Spark catalog, and will always append to the
- * underlying data of data source if the table already exists.
+ * In this method, save mode is used to determine the behavior if the data source table exists
+ * in Spark catalog. We will always overwrite the underlying data of data source (e.g. a table
+ * in JDBC data source) if the table doesn't exist in Spark catalog, and will always append to
+ * the underlying data of data source if the table already exists.
*
* When the DataFrame is created from a non-partitioned `HadoopFsRelation` with a single input
* path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC
@@ -310,25 +302,25 @@ abstract class DataFrameWriter[T] {
/**
* Saves the content of the `DataFrame` to an external database table via JDBC. In the case the
- * table already exists in the external database, behavior of this function depends on the
- * save mode, specified by the `mode` function (default to throwing an exception).
+ * table already exists in the external database, behavior of this function depends on the save
+ * mode, specified by the `mode` function (default to throwing an exception).
*
* Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash
* your external database systems.
*
- * JDBC-specific option and parameter documentation for storing tables via JDBC in
- *
- * Data Source Option in the version you use.
- *
- * @param table Name of the table in the external database.
- * @param connectionProperties JDBC database connection arguments, a list of arbitrary string
- * tag/value. Normally at least a "user" and "password" property
- * should be included. "batchsize" can be used to control the
- * number of rows per insert. "isolationLevel" can be one of
- * "NONE", "READ_COMMITTED", "READ_UNCOMMITTED", "REPEATABLE_READ",
- * or "SERIALIZABLE", corresponding to standard transaction
- * isolation levels defined by JDBC's Connection object, with default
- * of "READ_UNCOMMITTED".
+ * JDBC-specific option and parameter documentation for storing tables via JDBC in
+ * Data Source Option in the version you use.
+ *
+ * @param table
+ * Name of the table in the external database.
+ * @param connectionProperties
+ * JDBC database connection arguments, a list of arbitrary string tag/value. Normally at least
+ * a "user" and "password" property should be included. "batchsize" can be used to control the
+ * number of rows per insert. "isolationLevel" can be one of "NONE", "READ_COMMITTED",
+ * "READ_UNCOMMITTED", "REPEATABLE_READ", or "SERIALIZABLE", corresponding to standard
+ * transaction isolation levels defined by JDBC's Connection object, with default of
+ * "READ_UNCOMMITTED".
* @since 1.4.0
*/
def jdbc(url: String, table: String, connectionProperties: util.Properties): Unit = {
@@ -343,16 +335,16 @@ abstract class DataFrameWriter[T] {
}
/**
- * Saves the content of the `DataFrame` in JSON format (
- * JSON Lines text format or newline-delimited JSON) at the specified path.
- * This is equivalent to:
+ * Saves the content of the `DataFrame` in JSON format ( JSON
+ * Lines text format or newline-delimited JSON) at the specified path. This is equivalent
+ * to:
* {{{
* format("json").save(path)
* }}}
*
- * You can find the JSON-specific options for writing JSON files in
- *
- * Data Source Option in the version you use.
+ * You can find the JSON-specific options for writing JSON files in
+ * Data Source Option in the version you use.
*
* @since 1.4.0
*/
@@ -361,16 +353,15 @@ abstract class DataFrameWriter[T] {
}
/**
- * Saves the content of the `DataFrame` in Parquet format at the specified path.
- * This is equivalent to:
+ * Saves the content of the `DataFrame` in Parquet format at the specified path. This is
+ * equivalent to:
* {{{
* format("parquet").save(path)
* }}}
*
- * Parquet-specific option(s) for writing Parquet files can be found in
- *
- * Data Source Option in the version you use.
+ * Parquet-specific option(s) for writing Parquet files can be found in Data
+ * Source Option in the version you use.
*
* @since 1.4.0
*/
@@ -379,16 +370,15 @@ abstract class DataFrameWriter[T] {
}
/**
- * Saves the content of the `DataFrame` in ORC format at the specified path.
- * This is equivalent to:
+ * Saves the content of the `DataFrame` in ORC format at the specified path. This is equivalent
+ * to:
* {{{
* format("orc").save(path)
* }}}
*
- * ORC-specific option(s) for writing ORC files can be found in
- *
- * Data Source Option in the version you use.
+ * ORC-specific option(s) for writing ORC files can be found in Data
+ * Source Option in the version you use.
*
* @since 1.5.0
*/
@@ -397,9 +387,9 @@ abstract class DataFrameWriter[T] {
}
/**
- * Saves the content of the `DataFrame` in a text file at the specified path.
- * The DataFrame must have only one column that is of string type.
- * Each row becomes a new line in the output file. For example:
+ * Saves the content of the `DataFrame` in a text file at the specified path. The DataFrame must
+ * have only one column that is of string type. Each row becomes a new line in the output file.
+ * For example:
* {{{
* // Scala:
* df.write.text("/path/to/output")
@@ -409,9 +399,9 @@ abstract class DataFrameWriter[T] {
* }}}
* The text files will be encoded as UTF-8.
*
- * You can find the text-specific options for writing text files in
- *
- * Data Source Option in the version you use.
+ * You can find the text-specific options for writing text files in
+ * Data Source Option in the version you use.
*
* @since 1.6.0
*/
@@ -420,15 +410,15 @@ abstract class DataFrameWriter[T] {
}
/**
- * Saves the content of the `DataFrame` in CSV format at the specified path.
- * This is equivalent to:
+ * Saves the content of the `DataFrame` in CSV format at the specified path. This is equivalent
+ * to:
* {{{
* format("csv").save(path)
* }}}
*
- * You can find the CSV-specific options for writing CSV files in
- *
- * Data Source Option in the version you use.
+ * You can find the CSV-specific options for writing CSV files in
+ * Data Source Option in the version you use.
*
* @since 2.0.0
*/
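An illustrative CSV write using the option plumbing described earlier in this file (path and option values assumed; `header` and `compression` are standard CSV data source options):
{{{
  df.write
    .option("header", "true")
    .option("compression", "gzip")
    .mode("overwrite")
    .csv("/tmp/out/csv")
}}}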
@@ -437,31 +427,25 @@ abstract class DataFrameWriter[T] {
}
/**
- * Saves the content of the `DataFrame` in XML format at the specified path.
- * This is equivalent to:
+ * Saves the content of the `DataFrame` in XML format at the specified path. This is equivalent
+ * to:
* {{{
* format("xml").save(path)
* }}}
*
- * Note that writing a XML file from `DataFrame` having a field `ArrayType` with
- * its element as `ArrayType` would have an additional nested field for the element.
- * For example, the `DataFrame` having a field below,
+ * Note that writing a XML file from `DataFrame` having a field `ArrayType` with its element as
+ * `ArrayType` would have an additional nested field for the element. For example, the
+ * `DataFrame` having a field below,
*
- * {@code fieldA [[data1], [data2]]}
+ * {@code fieldA [[data1], [data2]]}
*
- * would produce a XML file below.
- * {@code
- *
- * data1
- *
- *
- * data2
- * }
+ * would produce a XML file below. {@code data1
+ * data2}
*
* Namely, roundtrip in writing and reading can end up in different schema structure.
*
- * You can find the XML-specific options for writing XML files in
- *
+ * You can find the XML-specific options for writing XML files in
* Data Source Option in the version you use.
*/
def xml(path: String): Unit = {
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala b/sql/api/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala
similarity index 62%
rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala
rename to sql/api/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala
index 3f9b224003914..37a29c2e4b66d 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala
@@ -14,111 +14,80 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
package org.apache.spark.sql
-import scala.collection.mutable
-import scala.jdk.CollectionConverters._
+import java.util
import org.apache.spark.annotation.Experimental
-import org.apache.spark.connect.proto
+import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException}
/**
- * Interface used to write a [[org.apache.spark.sql.Dataset]] to external storage using the v2
+ * Interface used to write a [[org.apache.spark.sql.api.Dataset]] to external storage using the v2
* API.
*
- * @since 3.4.0
+ * @since 3.0.0
*/
@Experimental
-final class DataFrameWriterV2[T] private[sql] (table: String, ds: Dataset[T])
- extends CreateTableWriter[T] {
- import ds.sparkSession.RichColumn
-
- private var provider: Option[String] = None
-
- private val options = new mutable.HashMap[String, String]()
+abstract class DataFrameWriterV2[T] extends CreateTableWriter[T] {
- private val properties = new mutable.HashMap[String, String]()
+ /** @inheritdoc */
+ override def using(provider: String): this.type
- private var partitioning: Option[Seq[proto.Expression]] = None
+ /** @inheritdoc */
+ override def option(key: String, value: Boolean): this.type = option(key, value.toString)
- private var clustering: Option[Seq[String]] = None
+ /** @inheritdoc */
+ override def option(key: String, value: Long): this.type = option(key, value.toString)
- private var overwriteCondition: Option[proto.Expression] = None
+ /** @inheritdoc */
+ override def option(key: String, value: Double): this.type = option(key, value.toString)
- override def using(provider: String): CreateTableWriter[T] = {
- this.provider = Some(provider)
- this
- }
+ /** @inheritdoc */
+ override def option(key: String, value: String): this.type
- override def option(key: String, value: String): DataFrameWriterV2[T] = {
- this.options.put(key, value)
- this
- }
+ /** @inheritdoc */
+ override def options(options: scala.collection.Map[String, String]): this.type
- override def options(options: scala.collection.Map[String, String]): DataFrameWriterV2[T] = {
- options.foreach { case (key, value) =>
- this.options.put(key, value)
- }
- this
- }
+ /** @inheritdoc */
+ override def options(options: util.Map[String, String]): this.type
- override def options(options: java.util.Map[String, String]): DataFrameWriterV2[T] = {
- this.options(options.asScala)
- this
- }
-
- override def tableProperty(property: String, value: String): CreateTableWriter[T] = {
- this.properties.put(property, value)
- this
- }
+ /** @inheritdoc */
+ override def tableProperty(property: String, value: String): this.type
+ /** @inheritdoc */
@scala.annotation.varargs
- override def partitionedBy(column: Column, columns: Column*): CreateTableWriter[T] = {
- val asTransforms = (column +: columns).map(_.expr)
- this.partitioning = Some(asTransforms)
- this
- }
+ override def partitionedBy(column: Column, columns: Column*): this.type
+ /** @inheritdoc */
@scala.annotation.varargs
- override def clusterBy(colName: String, colNames: String*): CreateTableWriter[T] = {
- this.clustering = Some(colName +: colNames)
- this
- }
-
- override def create(): Unit = {
- executeWriteOperation(proto.WriteOperationV2.Mode.MODE_CREATE)
- }
-
- override def replace(): Unit = {
- executeWriteOperation(proto.WriteOperationV2.Mode.MODE_REPLACE)
- }
-
- override def createOrReplace(): Unit = {
- executeWriteOperation(proto.WriteOperationV2.Mode.MODE_CREATE_OR_REPLACE)
- }
+ override def clusterBy(colName: String, colNames: String*): this.type
/**
* Append the contents of the data frame to the output table.
*
- * If the output table does not exist, this operation will fail. The data frame will be
+ * If the output table does not exist, this operation will fail with
+ * [[org.apache.spark.sql.catalyst.analysis.NoSuchTableException]]. The data frame will be
* validated to ensure it is compatible with the existing table.
+ *
+ * @throws org.apache.spark.sql.catalyst.analysis.NoSuchTableException
+ * If the table does not exist
*/
- def append(): Unit = {
- executeWriteOperation(proto.WriteOperationV2.Mode.MODE_APPEND)
- }
+ @throws(classOf[NoSuchTableException])
+ def append(): Unit
/**
* Overwrite rows matching the given filter condition with the contents of the data frame in the
* output table.
*
- * If the output table does not exist, this operation will fail. The data frame will be
+ * If the output table does not exist, this operation will fail with
+ * [[org.apache.spark.sql.catalyst.analysis.NoSuchTableException]]. The data frame will be
* validated to ensure it is compatible with the existing table.
+ *
+ * @throws org.apache.spark.sql.catalyst.analysis.NoSuchTableException
+ * If the table does not exist
*/
- def overwrite(condition: Column): Unit = {
- overwriteCondition = Some(condition.expr)
- executeWriteOperation(proto.WriteOperationV2.Mode.MODE_OVERWRITE)
- }
+ @throws(classOf[NoSuchTableException])
+ def overwrite(condition: Column): Unit
/**
* Overwrite all partition for which the data frame contains at least one row with the contents
@@ -127,85 +96,64 @@ final class DataFrameWriterV2[T] private[sql] (table: String, ds: Dataset[T])
* This operation is equivalent to Hive's `INSERT OVERWRITE ... PARTITION`, which replaces
* partitions dynamically depending on the contents of the data frame.
*
- * If the output table does not exist, this operation will fail. The data frame will be
+ * If the output table does not exist, this operation will fail with
+ * [[org.apache.spark.sql.catalyst.analysis.NoSuchTableException]]. The data frame will be
* validated to ensure it is compatible with the existing table.
+ *
+ * @throws org.apache.spark.sql.catalyst.analysis.NoSuchTableException
+ * If the table does not exist
*/
- def overwritePartitions(): Unit = {
- executeWriteOperation(proto.WriteOperationV2.Mode.MODE_OVERWRITE_PARTITIONS)
- }
-
- private def executeWriteOperation(mode: proto.WriteOperationV2.Mode): Unit = {
- val builder = proto.WriteOperationV2.newBuilder()
-
- builder.setInput(ds.plan.getRoot)
- builder.setTableName(table)
- provider.foreach(builder.setProvider)
-
- partitioning.foreach(columns => builder.addAllPartitioningColumns(columns.asJava))
- clustering.foreach(columns => builder.addAllClusteringColumns(columns.asJava))
-
- options.foreach { case (k, v) =>
- builder.putOptions(k, v)
- }
- properties.foreach { case (k, v) =>
- builder.putTableProperties(k, v)
- }
-
- builder.setMode(mode)
-
- overwriteCondition.foreach(builder.setOverwriteCondition)
-
- ds.sparkSession.execute(proto.Command.newBuilder().setWriteOperationV2(builder).build())
- }
+ @throws(classOf[NoSuchTableException])
+ def overwritePartitions(): Unit
}
/**
* Configuration methods common to create/replace operations and insert/overwrite operations.
* @tparam R
* builder type to return
- * @since 3.4.0
+ * @since 3.0.0
*/
trait WriteConfigMethods[R] {
/**
* Add a write option.
*
- * @since 3.4.0
+ * @since 3.0.0
*/
def option(key: String, value: String): R
/**
* Add a boolean output option.
*
- * @since 3.4.0
+ * @since 3.0.0
*/
def option(key: String, value: Boolean): R = option(key, value.toString)
/**
* Add a long output option.
*
- * @since 3.4.0
+ * @since 3.0.0
*/
def option(key: String, value: Long): R = option(key, value.toString)
/**
* Add a double output option.
*
- * @since 3.4.0
+ * @since 3.0.0
*/
def option(key: String, value: Double): R = option(key, value.toString)
/**
* Add write options from a Scala Map.
*
- * @since 3.4.0
+ * @since 3.0.0
*/
def options(options: scala.collection.Map[String, String]): R
/**
* Add write options from a Java Map.
*
- * @since 3.4.0
+ * @since 3.0.0
*/
def options(options: java.util.Map[String, String]): R
}
@@ -213,7 +161,7 @@ trait WriteConfigMethods[R] {
/**
* Trait to restrict calls to create and replace operations.
*
- * @since 3.4.0
+ * @since 3.0.0
*/
trait CreateTableWriter[T] extends WriteConfigMethods[CreateTableWriter[T]] {
@@ -223,8 +171,13 @@ trait CreateTableWriter[T] extends WriteConfigMethods[CreateTableWriter[T]] {
* The new table's schema, partition layout, properties, and other configuration will be based
* on the configuration set on this writer.
*
- * If the output table exists, this operation will fail.
+ * If the output table exists, this operation will fail with
+ * [[org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException]].
+ *
+ * @throws org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
+ * If the table already exists
*/
+ @throws(classOf[TableAlreadyExistsException])
def create(): Unit
/**
@@ -233,8 +186,13 @@ trait CreateTableWriter[T] extends WriteConfigMethods[CreateTableWriter[T]] {
* The existing table's schema, partition layout, properties, and other configuration will be
* replaced with the contents of the data frame and the configuration set on this writer.
*
- * If the output table does not exist, this operation will fail.
+ * If the output table does not exist, this operation will fail with
+ * [[org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException]].
+ *
+ * @throws org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException
+ * If the table does not exist
*/
+ @throws(classOf[CannotReplaceMissingTableException])
def replace(): Unit
/**
@@ -260,7 +218,7 @@ trait CreateTableWriter[T] extends WriteConfigMethods[CreateTableWriter[T]] {
* predicates on the partitioned columns. In order for partitioning to work well, the number of
* distinct values in each column should typically be less than tens of thousands.
*
- * @since 3.4.0
+ * @since 3.0.0
*/
@scala.annotation.varargs
def partitionedBy(column: Column, columns: Column*): CreateTableWriter[T]
@@ -282,7 +240,7 @@ trait CreateTableWriter[T] extends WriteConfigMethods[CreateTableWriter[T]] {
* Specifies a provider for the underlying output data source. Spark's default catalog supports
* "parquet", "json", etc.
*
- * @since 3.4.0
+ * @since 3.0.0
*/
def using(provider: String): CreateTableWriter[T]
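The refactored v2 writer is reached through `Dataset.writeTo`; a hedged sketch of a typical chain (catalog, table, and column names assumed):
{{{
  import org.apache.spark.sql.functions.col

  df.writeTo("catalog.db.events")
    .using("parquet")
    .partitionedBy(col("year"), col("month"))
    .createOrReplace()   // create() fails if the table exists, replace() if it does not

  df.writeTo("catalog.db.events").append()   // NoSuchTableException if the table is missing
}}}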
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/api/src/main/scala/org/apache/spark/sql/Encoder.scala
index ea760d80541c8..d125e89b8c410 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/Encoder.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.types._
/**
* Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.
*
- * == Scala ==
+ * ==Scala==
* Encoders are generally created automatically through implicits from a `SparkSession`, or can be
* explicitly created by calling static methods on [[Encoders]].
*
@@ -35,7 +35,7 @@ import org.apache.spark.sql.types._
* val ds = Seq(1, 2, 3).toDS() // implicitly provided (spark.implicits.newIntEncoder)
* }}}
*
- * == Java ==
+ * ==Java==
* Encoders are specified by calling static methods on [[Encoders]].
*
* {{{
@@ -57,8 +57,8 @@ import org.apache.spark.sql.types._
* Encoders.bean(MyClass.class);
* }}}
*
- * == Implementation ==
- * - Encoders should be thread-safe.
+ * ==Implementation==
+ * - Encoders should be thread-safe.
*
* @since 1.6.0
*/
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/api/src/main/scala/org/apache/spark/sql/Encoders.scala
similarity index 72%
rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala
rename to sql/api/src/main/scala/org/apache/spark/sql/Encoders.scala
index ffd9975770066..9976b34f7a01f 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Encoders.scala
@@ -14,95 +14,99 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
package org.apache.spark.sql
-import scala.reflect.ClassTag
+import java.lang.reflect.Modifier
+
+import scala.reflect.{classTag, ClassTag}
import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
-import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, JavaSerializationCodec, RowEncoder => RowEncoderFactory}
+import org.apache.spark.sql.catalyst.encoders.{Codec, JavaSerializationCodec, KryoSerializationCodec, RowEncoder => SchemaInference}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.errors.ExecutionErrors
+import org.apache.spark.sql.types._
/**
* Methods for creating an [[Encoder]].
*
- * @since 3.5.0
+ * @since 1.6.0
*/
object Encoders {
/**
* An encoder for nullable boolean type. The Scala primitive encoder is available as
* [[scalaBoolean]].
- * @since 3.5.0
+ * @since 1.6.0
*/
def BOOLEAN: Encoder[java.lang.Boolean] = BoxedBooleanEncoder
/**
* An encoder for nullable byte type. The Scala primitive encoder is available as [[scalaByte]].
- * @since 3.5.0
+ * @since 1.6.0
*/
def BYTE: Encoder[java.lang.Byte] = BoxedByteEncoder
/**
* An encoder for nullable short type. The Scala primitive encoder is available as
* [[scalaShort]].
- * @since 3.5.0
+ * @since 1.6.0
*/
def SHORT: Encoder[java.lang.Short] = BoxedShortEncoder
/**
* An encoder for nullable int type. The Scala primitive encoder is available as [[scalaInt]].
- * @since 3.5.0
+ * @since 1.6.0
*/
def INT: Encoder[java.lang.Integer] = BoxedIntEncoder
/**
* An encoder for nullable long type. The Scala primitive encoder is available as [[scalaLong]].
- * @since 3.5.0
+ * @since 1.6.0
*/
def LONG: Encoder[java.lang.Long] = BoxedLongEncoder
/**
* An encoder for nullable float type. The Scala primitive encoder is available as
* [[scalaFloat]].
- * @since 3.5.0
+ * @since 1.6.0
*/
def FLOAT: Encoder[java.lang.Float] = BoxedFloatEncoder
/**
* An encoder for nullable double type. The Scala primitive encoder is available as
* [[scalaDouble]].
- * @since 3.5.0
+ * @since 1.6.0
*/
def DOUBLE: Encoder[java.lang.Double] = BoxedDoubleEncoder
/**
* An encoder for nullable string type.
*
- * @since 3.5.0
+ * @since 1.6.0
*/
def STRING: Encoder[java.lang.String] = StringEncoder
/**
* An encoder for nullable decimal type.
*
- * @since 3.5.0
+ * @since 1.6.0
*/
def DECIMAL: Encoder[java.math.BigDecimal] = DEFAULT_JAVA_DECIMAL_ENCODER
/**
* An encoder for nullable date type.
*
- * @since 3.5.0
+ * @since 1.6.0
*/
- def DATE: Encoder[java.sql.Date] = DateEncoder(lenientSerialization = false)
+ def DATE: Encoder[java.sql.Date] = STRICT_DATE_ENCODER
/**
* Creates an encoder that serializes instances of the `java.time.LocalDate` class to the
* internal representation of nullable Catalyst's DateType.
*
- * @since 3.5.0
+ * @since 3.0.0
*/
def LOCALDATE: Encoder[java.time.LocalDate] = STRICT_LOCAL_DATE_ENCODER
@@ -110,14 +114,14 @@ object Encoders {
* Creates an encoder that serializes instances of the `java.time.LocalDateTime` class to the
* internal representation of nullable Catalyst's TimestampNTZType.
*
- * @since 3.5.0
+ * @since 3.4.0
*/
def LOCALDATETIME: Encoder[java.time.LocalDateTime] = LocalDateTimeEncoder
/**
* An encoder for nullable timestamp type.
*
- * @since 3.5.0
+ * @since 1.6.0
*/
def TIMESTAMP: Encoder[java.sql.Timestamp] = STRICT_TIMESTAMP_ENCODER
@@ -125,14 +129,14 @@ object Encoders {
* Creates an encoder that serializes instances of the `java.time.Instant` class to the internal
* representation of nullable Catalyst's TimestampType.
*
- * @since 3.5.0
+ * @since 3.0.0
*/
def INSTANT: Encoder[java.time.Instant] = STRICT_INSTANT_ENCODER
/**
* An encoder for arrays of bytes.
*
- * @since 3.5.0
+ * @since 1.6.1
*/
def BINARY: Encoder[Array[Byte]] = BinaryEncoder
@@ -140,7 +144,7 @@ object Encoders {
* Creates an encoder that serializes instances of the `java.time.Duration` class to the
* internal representation of nullable Catalyst's DayTimeIntervalType.
*
- * @since 3.5.0
+ * @since 3.2.0
*/
def DURATION: Encoder[java.time.Duration] = DayTimeIntervalEncoder
@@ -148,7 +152,7 @@ object Encoders {
* Creates an encoder that serializes instances of the `java.time.Period` class to the internal
* representation of nullable Catalyst's YearMonthIntervalType.
*
- * @since 3.5.0
+ * @since 3.2.0
*/
def PERIOD: Encoder[java.time.Period] = YearMonthIntervalEncoder
@@ -166,7 +170,7 @@ object Encoders {
* - collection types: array, java.util.List, and map
* - nested java bean.
*
- * @since 3.5.0
+ * @since 1.6.0
*/
def bean[T](beanClass: Class[T]): Encoder[T] = JavaTypeInference.encoderFor(beanClass)
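A sketch of the bean encoder (the `Event` bean is assumed and must follow the supported-field rules listed above; `spark` is an active SparkSession):
{{{
  // public class Event { private String name; private java.sql.Timestamp ts; /* getters/setters */ }
  val eventEncoder = Encoders.bean(classOf[Event])
  val ds = spark.createDataset(java.util.Arrays.asList(new Event()))(eventEncoder)
}}}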
@@ -175,7 +179,27 @@ object Encoders {
*
* @since 3.5.0
*/
- def row(schema: StructType): Encoder[Row] = RowEncoderFactory.encoderFor(schema)
+ def row(schema: StructType): Encoder[Row] = SchemaInference.encoderFor(schema)
+
+ /**
+ * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. This
+ * encoder maps T into a single byte array (binary) field.
+ *
+ * T must be publicly accessible.
+ *
+ * @since 1.6.0
+ */
+ def kryo[T: ClassTag]: Encoder[T] = genericSerializer(KryoSerializationCodec)
+
+ /**
+ * Creates an encoder that serializes objects of type T using Kryo. This encoder maps T into a
+ * single byte array (binary) field.
+ *
+ * T must be publicly accessible.
+ *
+ * @since 1.6.0
+ */
+ def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz))
/**
* (Scala-specific) Creates an encoder that serializes objects of type T using generic Java
@@ -185,11 +209,10 @@ object Encoders {
*
* @note
* This is extremely inefficient and should only be used as the last resort.
- * @since 4.0.0
+ *
+ * @since 1.6.0
*/
- def javaSerialization[T: ClassTag]: Encoder[T] = {
- TransformingEncoder(implicitly[ClassTag[T]], BinaryEncoder, JavaSerializationCodec)
- }
+ def javaSerialization[T: ClassTag]: Encoder[T] = genericSerializer(JavaSerializationCodec)
/**
* Creates an encoder that serializes objects of type T using generic Java serialization. This
@@ -199,25 +222,53 @@ object Encoders {
*
* @note
* This is extremely inefficient and should only be used as the last resort.
- * @since 4.0.0
+ *
+ * @since 1.6.0
*/
- def javaSerialization[T](clazz: Class[T]): Encoder[T] = javaSerialization(ClassTag[T](clazz))
+ def javaSerialization[T](clazz: Class[T]): Encoder[T] =
+ javaSerialization(ClassTag[T](clazz))
+
+ /** Throws an exception if T is not a public class. */
+ private def validatePublicClass[T: ClassTag](): Unit = {
+ if (!Modifier.isPublic(classTag[T].runtimeClass.getModifiers)) {
+ throw ExecutionErrors.notPublicClassError(classTag[T].runtimeClass.getName)
+ }
+ }
+
+ /** A way to construct encoders using generic serializers. */
+ private def genericSerializer[T: ClassTag](
+ provider: () => Codec[Any, Array[Byte]]): Encoder[T] = {
+ if (classTag[T].runtimeClass.isPrimitive) {
+ throw ExecutionErrors.primitiveTypesNotSupportedError()
+ }
+
+ validatePublicClass[T]()
+
+ TransformingEncoder(classTag[T], BinaryEncoder, provider)
+ }
- private def tupleEncoder[T](encoders: Encoder[_]*): Encoder[T] = {
- ProductEncoder.tuple(encoders.asInstanceOf[Seq[AgnosticEncoder[_]]]).asInstanceOf[Encoder[T]]
+ private[sql] def tupleEncoder[T](encoders: Encoder[_]*): Encoder[T] = {
+ ProductEncoder.tuple(encoders.map(agnosticEncoderFor(_))).asInstanceOf[Encoder[T]]
}
+ /**
+ * An encoder for 1-ary tuples.
+ *
+ * @since 4.0.0
+ */
+ def tuple[T1](e1: Encoder[T1]): Encoder[(T1)] = tupleEncoder(e1)
+
/**
* An encoder for 2-ary tuples.
*
- * @since 3.5.0
+ * @since 1.6.0
*/
def tuple[T1, T2](e1: Encoder[T1], e2: Encoder[T2]): Encoder[(T1, T2)] = tupleEncoder(e1, e2)
/**
* An encoder for 3-ary tuples.
*
- * @since 3.5.0
+ * @since 1.6.0
*/
def tuple[T1, T2, T3](
e1: Encoder[T1],
@@ -227,7 +278,7 @@ object Encoders {
/**
* An encoder for 4-ary tuples.
*
- * @since 3.5.0
+ * @since 1.6.0
*/
def tuple[T1, T2, T3, T4](
e1: Encoder[T1],
@@ -238,7 +289,7 @@ object Encoders {
/**
* An encoder for 5-ary tuples.
*
- * @since 3.5.0
+ * @since 1.6.0
*/
def tuple[T1, T2, T3, T4, T5](
e1: Encoder[T1],
@@ -249,49 +300,50 @@ object Encoders {
/**
* An encoder for Scala's product type (tuples, case classes, etc).
- * @since 3.5.0
+ * @since 2.0.0
*/
def product[T <: Product: TypeTag]: Encoder[T] = ScalaReflection.encoderFor[T]
/**
* An encoder for Scala's primitive int type.
- * @since 3.5.0
+ * @since 2.0.0
*/
def scalaInt: Encoder[Int] = PrimitiveIntEncoder
/**
* An encoder for Scala's primitive long type.
- * @since 3.5.0
+ * @since 2.0.0
*/
def scalaLong: Encoder[Long] = PrimitiveLongEncoder
/**
* An encoder for Scala's primitive double type.
- * @since 3.5.0
+ * @since 2.0.0
*/
def scalaDouble: Encoder[Double] = PrimitiveDoubleEncoder
/**
* An encoder for Scala's primitive float type.
- * @since 3.5.0
+ * @since 2.0.0
*/
def scalaFloat: Encoder[Float] = PrimitiveFloatEncoder
/**
* An encoder for Scala's primitive byte type.
- * @since 3.5.0
+ * @since 2.0.0
*/
def scalaByte: Encoder[Byte] = PrimitiveByteEncoder
/**
* An encoder for Scala's primitive short type.
- * @since 3.5.0
+ * @since 2.0.0
*/
def scalaShort: Encoder[Short] = PrimitiveShortEncoder
/**
* An encoder for Scala's primitive boolean type.
- * @since 3.5.0
+ * @since 2.0.0
*/
def scalaBoolean: Encoder[Boolean] = PrimitiveBooleanEncoder
+
}
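A sketch of the serializer-backed and tuple encoders surfaced above (the `Box` class is assumed; it must be public and non-primitive or the validation added in `genericSerializer` throws):
{{{
  case class Box(value: String)

  val kryoEnc: Encoder[Box] = Encoders.kryo[Box]               // single binary column
  val javaEnc: Encoder[Box] = Encoders.javaSerialization[Box]  // last-resort fallback
  val pairEnc: Encoder[(String, Long)] = Encoders.tuple(Encoders.STRING, Encoders.scalaLong)
}}}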
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/ForeachWriter.scala b/sql/api/src/main/scala/org/apache/spark/sql/ForeachWriter.scala
similarity index 99%
rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/ForeachWriter.scala
rename to sql/api/src/main/scala/org/apache/spark/sql/ForeachWriter.scala
index 756356f7f0282..4e2bb35146a63 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/ForeachWriter.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/ForeachWriter.scala
@@ -89,7 +89,7 @@ package org.apache.spark.sql
* });
* }}}
*
- * @since 3.5.0
+ * @since 2.0.0
*/
abstract class ForeachWriter[T] extends Serializable {
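A minimal sketch of implementing the abstract class above for a streaming sink (the println target is illustrative only):
{{{
  val writer = new ForeachWriter[String] {
    def open(partitionId: Long, epochId: Long): Boolean = true  // set up per-partition state
    def process(value: String): Unit = println(value)           // handle one record
    def close(errorOrNull: Throwable): Unit = ()                // release resources
  }
  // streamingDF.writeStream.foreach(writer).start()
}}}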
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala b/sql/api/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala
similarity index 65%
rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala
rename to sql/api/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala
index 71813af1e354f..db56b39e28aeb 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala
@@ -14,46 +14,25 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
package org.apache.spark.sql
-import scala.jdk.CollectionConverters._
-
-import org.apache.spark.SparkRuntimeException
import org.apache.spark.annotation.Experimental
-import org.apache.spark.connect.proto
-import org.apache.spark.connect.proto.{Expression, MergeIntoTableCommand}
-import org.apache.spark.connect.proto.MergeAction
-import org.apache.spark.sql.functions.expr
/**
* `MergeIntoWriter` provides methods to define and execute merge actions based on specified
* conditions.
*
+ * Please note that schema evolution is disabled by default.
+ *
* @tparam T
* the type of data in the Dataset.
- * @param table
- * the name of the target table for the merge operation.
- * @param ds
- * the source Dataset to merge into the target table.
- * @param on
- * the merge condition.
- * @param schemaEvolutionEnabled
- * whether to enable automatic schema evolution for this merge operation. Default is `false`.
- *
* @since 4.0.0
*/
@Experimental
-class MergeIntoWriter[T] private[sql] (
- table: String,
- ds: Dataset[T],
- on: Column,
- schemaEvolutionEnabled: Boolean = false) {
- import ds.sparkSession.RichColumn
+abstract class MergeIntoWriter[T] {
+ private var schemaEvolution: Boolean = false
- private[sql] var matchedActions: Seq[MergeAction] = Seq.empty[MergeAction]
- private[sql] var notMatchedActions: Seq[MergeAction] = Seq.empty[MergeAction]
- private[sql] var notMatchedBySourceActions: Seq[MergeAction] = Seq.empty[MergeAction]
+ private[sql] def schemaEvolutionEnabled: Boolean = schemaEvolution
/**
* Initialize a `WhenMatched` action without any condition.
@@ -176,84 +155,39 @@ class MergeIntoWriter[T] private[sql] (
/**
* Enable automatic schema evolution for this merge operation.
+ *
* @return
* A `MergeIntoWriter` instance with schema evolution enabled.
*/
def withSchemaEvolution(): MergeIntoWriter[T] = {
- new MergeIntoWriter[T](this.table, this.ds, this.on, schemaEvolutionEnabled = true)
- .withNewMatchedActions(this.matchedActions: _*)
- .withNewNotMatchedActions(this.notMatchedActions: _*)
- .withNewNotMatchedBySourceActions(this.notMatchedBySourceActions: _*)
+ schemaEvolution = true
+ this
}
/**
* Executes the merge operation.
*/
- def merge(): Unit = {
- if (matchedActions.isEmpty && notMatchedActions.isEmpty && notMatchedBySourceActions.isEmpty) {
- throw new SparkRuntimeException(
- errorClass = "NO_MERGE_ACTION_SPECIFIED",
- messageParameters = Map.empty)
- }
+ def merge(): Unit
- val matchedActionExpressions =
- matchedActions.map(Expression.newBuilder().setMergeAction(_)).map(_.build())
- val notMatchedActionExpressions =
- notMatchedActions.map(Expression.newBuilder().setMergeAction(_)).map(_.build())
- val notMatchedBySourceActionExpressions =
- notMatchedBySourceActions.map(Expression.newBuilder().setMergeAction(_)).map(_.build())
- val mergeIntoCommand = MergeIntoTableCommand
- .newBuilder()
- .setTargetTableName(table)
- .setSourceTablePlan(ds.plan.getRoot)
- .setMergeCondition(on.expr)
- .addAllMatchActions(matchedActionExpressions.asJava)
- .addAllNotMatchedActions(notMatchedActionExpressions.asJava)
- .addAllNotMatchedBySourceActions(notMatchedBySourceActionExpressions.asJava)
- .setWithSchemaEvolution(schemaEvolutionEnabled)
- .build()
+ // Action callbacks.
+ protected[sql] def insertAll(condition: Option[Column]): MergeIntoWriter[T]
- ds.sparkSession.execute(
- proto.Command
- .newBuilder()
- .setMergeIntoTableCommand(mergeIntoCommand)
- .build())
- }
+ protected[sql] def insert(
+ condition: Option[Column],
+ map: Map[String, Column]): MergeIntoWriter[T]
- private[sql] def withNewMatchedActions(action: MergeAction*): MergeIntoWriter[T] = {
- this.matchedActions = this.matchedActions :++ action
- this
- }
+ protected[sql] def updateAll(
+ condition: Option[Column],
+ notMatchedBySource: Boolean): MergeIntoWriter[T]
- private[sql] def withNewNotMatchedActions(action: MergeAction*): MergeIntoWriter[T] = {
- this.notMatchedActions = this.notMatchedActions :++ action
- this
- }
+ protected[sql] def update(
+ condition: Option[Column],
+ map: Map[String, Column],
+ notMatchedBySource: Boolean): MergeIntoWriter[T]
- private[sql] def withNewNotMatchedBySourceActions(action: MergeAction*): MergeIntoWriter[T] = {
- this.notMatchedBySourceActions = this.notMatchedBySourceActions :++ action
- this
- }
-
- private[sql] def buildMergeAction(
- actionType: MergeAction.ActionType,
- conditionOpt: Option[Column],
- assignmentsOpt: Option[Map[String, Column]] = None): MergeAction = {
- val assignmentsProtoOpt = assignmentsOpt.map {
- _.map { case (k, v) =>
- MergeAction.Assignment
- .newBuilder()
- .setKey(expr(k).expr)
- .setValue(v.expr)
- .build()
- }.toSeq.asJava
- }
-
- val builder = MergeAction.newBuilder().setActionType(actionType)
- conditionOpt.map(c => builder.setCondition(c.expr))
- assignmentsProtoOpt.map(builder.addAllAssignments)
- builder.build()
- }
+ protected[sql] def delete(
+ condition: Option[Column],
+ notMatchedBySource: Boolean): MergeIntoWriter[T]
}
/**
@@ -265,7 +199,6 @@ class MergeIntoWriter[T] private[sql] (
* @param condition
* An optional condition Expression that specifies when the actions should be applied. If the
* condition is None, the actions will be applied to all matched rows.
- *
* @tparam T
* The type of data in the MergeIntoWriter.
*/
@@ -279,11 +212,8 @@ case class WhenMatched[T] private[sql] (
* @return
* The MergeIntoWriter instance with the update all action configured.
*/
- def updateAll(): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewMatchedActions(
- mergeIntoWriter
- .buildMergeAction(MergeAction.ActionType.ACTION_TYPE_UPDATE_STAR, condition))
- }
+ def updateAll(): MergeIntoWriter[T] =
+ mergeIntoWriter.updateAll(condition, notMatchedBySource = false)
/**
* Specifies an action to update matched rows in the DataFrame with the provided column
@@ -294,11 +224,8 @@ case class WhenMatched[T] private[sql] (
* @return
* The MergeIntoWriter instance with the update action configured.
*/
- def update(map: Map[String, Column]): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewMatchedActions(
- mergeIntoWriter
- .buildMergeAction(MergeAction.ActionType.ACTION_TYPE_UPDATE, condition, Some(map)))
- }
+ def update(map: Map[String, Column]): MergeIntoWriter[T] =
+ mergeIntoWriter.update(condition, map, notMatchedBySource = false)
/**
* Specifies an action to delete matched rows from the DataFrame.
@@ -306,10 +233,8 @@ case class WhenMatched[T] private[sql] (
* @return
* The MergeIntoWriter instance with the delete action configured.
*/
- def delete(): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewMatchedActions(
- mergeIntoWriter.buildMergeAction(MergeAction.ActionType.ACTION_TYPE_DELETE, condition))
- }
+ def delete(): MergeIntoWriter[T] =
+ mergeIntoWriter.delete(condition, notMatchedBySource = false)
}
/**
@@ -335,11 +260,8 @@ case class WhenNotMatched[T] private[sql] (
* @return
* The MergeIntoWriter instance with the insert all action configured.
*/
- def insertAll(): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewNotMatchedActions(
- mergeIntoWriter
- .buildMergeAction(MergeAction.ActionType.ACTION_TYPE_INSERT_STAR, condition))
- }
+ def insertAll(): MergeIntoWriter[T] =
+ mergeIntoWriter.insertAll(condition)
/**
* Specifies an action to insert non-matched rows into the DataFrame with the provided column
@@ -350,11 +272,8 @@ case class WhenNotMatched[T] private[sql] (
* @return
* The MergeIntoWriter instance with the insert action configured.
*/
- def insert(map: Map[String, Column]): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewNotMatchedActions(
- mergeIntoWriter
- .buildMergeAction(MergeAction.ActionType.ACTION_TYPE_INSERT, condition, Some(map)))
- }
+ def insert(map: Map[String, Column]): MergeIntoWriter[T] =
+ mergeIntoWriter.insert(condition, map)
}
/**
@@ -379,11 +298,8 @@ case class WhenNotMatchedBySource[T] private[sql] (
* @return
* The MergeIntoWriter instance with the update all action configured.
*/
- def updateAll(): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewNotMatchedBySourceActions(
- mergeIntoWriter
- .buildMergeAction(MergeAction.ActionType.ACTION_TYPE_UPDATE_STAR, condition))
- }
+ def updateAll(): MergeIntoWriter[T] =
+ mergeIntoWriter.updateAll(condition, notMatchedBySource = true)
/**
* Specifies an action to update non-matched rows in the target DataFrame with the provided
@@ -394,11 +310,8 @@ case class WhenNotMatchedBySource[T] private[sql] (
* @return
* The MergeIntoWriter instance with the update action configured.
*/
- def update(map: Map[String, Column]): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewNotMatchedBySourceActions(
- mergeIntoWriter
- .buildMergeAction(MergeAction.ActionType.ACTION_TYPE_UPDATE, condition, Some(map)))
- }
+ def update(map: Map[String, Column]): MergeIntoWriter[T] =
+ mergeIntoWriter.update(condition, map, notMatchedBySource = true)
/**
* Specifies an action to delete non-matched rows from the target DataFrame when not matched by
@@ -407,9 +320,6 @@ case class WhenNotMatchedBySource[T] private[sql] (
* @return
* The MergeIntoWriter instance with the delete action configured.
*/
- def delete(): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewNotMatchedBySourceActions(
- mergeIntoWriter
- .buildMergeAction(MergeAction.ActionType.ACTION_TYPE_DELETE, condition))
- }
+ def delete(): MergeIntoWriter[T] =
+ mergeIntoWriter.delete(condition, notMatchedBySource = true)
}
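
The refactor above turns `MergeIntoWriter` into an abstract builder with protected action callbacks, leaving the public `whenMatched`/`whenNotMatched`/`whenNotMatchedBySource`/`withSchemaEvolution`/`merge` surface unchanged. A hedged usage sketch follows; the `Dataset.mergeInto` entry point, the table name `target`, the sample rows, and the way the condition references the two sides are assumptions for illustration only.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

object MergeIntoExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("merge-into").getOrCreate()
    import spark.implicits._

    // Hypothetical source rows; "target" is assumed to be an existing table in a
    // catalog whose data source supports MERGE INTO.
    val source = Seq((1, "a"), (2, "b")).toDF("id", "value")

    source
      .mergeInto("target", col("target.id") === col("id"))
      .whenMatched()
      .updateAll()
      .whenNotMatched()
      .insertAll()
      .whenNotMatchedBySource()
      .delete()
      .merge()

    spark.stop()
  }
}
```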
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala b/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala
index 02f5a8de1e3f6..fa427fe651907 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala
@@ -38,13 +38,14 @@ import org.apache.spark.util.SparkThreadUtils
* val metrics = observation.get
* }}}
*
- * This collects the metrics while the first action is executed on the observed dataset. Subsequent
- * actions do not modify the metrics returned by [[get]]. Retrieval of the metric via [[get]]
- * blocks until the first action has finished and metrics become available.
+ * This collects the metrics while the first action is executed on the observed dataset.
+ * Subsequent actions do not modify the metrics returned by [[get]]. Retrieval of the metric via
+ * [[get]] blocks until the first action has finished and metrics become available.
*
* This class does not support streaming datasets.
*
- * @param name name of the metric
+ * @param name
+ * name of the metric
* @since 3.3.0
*/
class Observation(val name: String) {
@@ -65,23 +66,27 @@ class Observation(val name: String) {
val future: Future[Map[String, Any]] = promise.future
/**
- * (Scala-specific) Get the observed metrics. This waits for the observed dataset to finish
- * its first action. Only the result of the first action is available. Subsequent actions do not
+ * (Scala-specific) Get the observed metrics. This waits for the observed dataset to finish its
+ * first action. Only the result of the first action is available. Subsequent actions do not
* modify the result.
*
- * @return the observed metrics as a `Map[String, Any]`
- * @throws InterruptedException interrupted while waiting
+ * @return
+ * the observed metrics as a `Map[String, Any]`
+ * @throws InterruptedException
+ * interrupted while waiting
*/
@throws[InterruptedException]
def get: Map[String, Any] = SparkThreadUtils.awaitResult(future, Duration.Inf)
/**
- * (Java-specific) Get the observed metrics. This waits for the observed dataset to finish
- * its first action. Only the result of the first action is available. Subsequent actions do not
+ * (Java-specific) Get the observed metrics. This waits for the observed dataset to finish its
+ * first action. Only the result of the first action is available. Subsequent actions do not
* modify the result.
*
- * @return the observed metrics as a `java.util.Map[String, Object]`
- * @throws InterruptedException interrupted while waiting
+ * @return
+ * the observed metrics as a `java.util.Map[String, Object]`
+ * @throws InterruptedException
+ * interrupted while waiting
*/
@throws[InterruptedException]
def getAsJava: java.util.Map[String, Any] = get.asJava
@@ -89,7 +94,8 @@ class Observation(val name: String) {
/**
* Get the observed metrics. This returns the metrics if they are available, otherwise an empty map.
*
- * @return the observed metrics as a `Map[String, Any]`
+ * @return
+ * the observed metrics as a `Map[String, Any]`
*/
@throws[InterruptedException]
private[sql] def getOrEmpty: Map[String, Any] = {
@@ -108,7 +114,8 @@ class Observation(val name: String) {
/**
* Set the observed metrics and notify all waiting threads to resume.
*
- * @return `true` if all waiting threads were notified, `false` if otherwise.
+ * @return
+ * `true` if all waiting threads were notified, `false` if otherwise.
*/
private[sql] def setMetricsAndNotify(metrics: Row): Boolean = {
val metricsMap = metrics.getValuesMap(metrics.schema.map(_.name))
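
The `Observation` changes above are scaladoc reflow only. For orientation, a minimal sketch of the workflow the scaladoc describes, assuming a local `SparkSession` and illustrative metric columns:

```scala
import org.apache.spark.sql.{Observation, SparkSession}
import org.apache.spark.sql.functions.{count, max}

object ObservationExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("observation").getOrCreate()
    import spark.implicits._

    val observation = Observation("stats")
    val observed = Seq(1, 2, 3, 4).toDF("value")
      .observe(observation, count($"value").as("rows"), max($"value").as("max_value"))

    // The first action materialises the metrics; get blocks until they are available.
    observed.collect()
    val metrics: Map[String, Any] = observation.get
    println(metrics) // e.g. Map(rows -> 4, max_value -> 4)

    spark.stop()
  }
}
```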
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Row.scala b/sql/api/src/main/scala/org/apache/spark/sql/Row.scala
index fb4b4a6f37c8d..aa14115453aea 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Row.scala
@@ -45,6 +45,7 @@ import org.apache.spark.util.ArrayImplicits._
*/
@Stable
object Row {
+
/**
* This method can be used to extract fields from a [[Row]] object in a pattern match. Example:
* {{{
@@ -83,9 +84,8 @@ object Row {
val empty = apply()
}
-
/**
- * Represents one row of output from a relational operator. Allows both generic access by ordinal,
+ * Represents one row of output from a relational operator. Allows both generic access by ordinal,
* which will incur boxing overhead for primitives, as well as native primitive access.
*
* It is invalid to use the native primitive interface to retrieve a value that is null, instead a
@@ -103,9 +103,9 @@ object Row {
* Row.fromSeq(Seq(value1, value2, ...))
* }}}
*
- * A value of a row can be accessed through both generic access by ordinal,
- * which will incur boxing overhead for primitives, as well as native primitive access.
- * An example of generic access by ordinal:
+ * A value of a row can be accessed through both generic access by ordinal, which will incur
+ * boxing overhead for primitives, as well as native primitive access. An example of generic
+ * access by ordinal:
* {{{
* import org.apache.spark.sql._
*
@@ -117,10 +117,9 @@ object Row {
* // fourthValue: Any = null
* }}}
*
- * For native primitive access, it is invalid to use the native primitive interface to retrieve
- * a value that is null, instead a user must check `isNullAt` before attempting to retrieve a
- * value that might be null.
- * An example of native primitive access:
+ * For native primitive access, it is invalid to use the native primitive interface to retrieve a
+ * value that is null, instead a user must check `isNullAt` before attempting to retrieve a value
+ * that might be null. An example of native primitive access:
* {{{
* // using the row from the previous example.
* val firstValue = row.getInt(0)
@@ -143,6 +142,7 @@ object Row {
*/
@Stable
trait Row extends Serializable {
+
/** Number of elements in the Row. */
def size: Int = length
@@ -155,8 +155,8 @@ trait Row extends Serializable {
def schema: StructType = null
/**
- * Returns the value at position i. If the value is null, null is returned. The following
- * is a mapping between Spark SQL types and return types:
+ * Returns the value at position i. If the value is null, null is returned. The following is a
+ * mapping between Spark SQL types and return types:
*
* {{{
* BooleanType -> java.lang.Boolean
@@ -184,8 +184,8 @@ trait Row extends Serializable {
def apply(i: Int): Any = get(i)
/**
- * Returns the value at position i. If the value is null, null is returned. The following
- * is a mapping between Spark SQL types and return types:
+ * Returns the value at position i. If the value is null, null is returned. The following is a
+ * mapping between Spark SQL types and return types:
*
* {{{
* BooleanType -> java.lang.Boolean
@@ -218,106 +218,127 @@ trait Row extends Serializable {
/**
* Returns the value at position i as a primitive boolean.
*
- * @throws ClassCastException when data type does not match.
- * @throws org.apache.spark.SparkRuntimeException when value is null.
+ * @throws ClassCastException
+ * when data type does not match.
+ * @throws org.apache.spark.SparkRuntimeException
+ * when value is null.
*/
def getBoolean(i: Int): Boolean = getAnyValAs[Boolean](i)
/**
* Returns the value at position i as a primitive byte.
*
- * @throws ClassCastException when data type does not match.
- * @throws org.apache.spark.SparkRuntimeException when value is null.
+ * @throws ClassCastException
+ * when data type does not match.
+ * @throws org.apache.spark.SparkRuntimeException
+ * when value is null.
*/
def getByte(i: Int): Byte = getAnyValAs[Byte](i)
/**
* Returns the value at position i as a primitive short.
*
- * @throws ClassCastException when data type does not match.
- * @throws org.apache.spark.SparkRuntimeException when value is null.
+ * @throws ClassCastException
+ * when data type does not match.
+ * @throws org.apache.spark.SparkRuntimeException
+ * when value is null.
*/
def getShort(i: Int): Short = getAnyValAs[Short](i)
/**
* Returns the value at position i as a primitive int.
*
- * @throws ClassCastException when data type does not match.
- * @throws org.apache.spark.SparkRuntimeException when value is null.
+ * @throws ClassCastException
+ * when data type does not match.
+ * @throws org.apache.spark.SparkRuntimeException
+ * when value is null.
*/
def getInt(i: Int): Int = getAnyValAs[Int](i)
/**
* Returns the value at position i as a primitive long.
*
- * @throws ClassCastException when data type does not match.
- * @throws org.apache.spark.SparkRuntimeException when value is null.
+ * @throws ClassCastException
+ * when data type does not match.
+ * @throws org.apache.spark.SparkRuntimeException
+ * when value is null.
*/
def getLong(i: Int): Long = getAnyValAs[Long](i)
/**
- * Returns the value at position i as a primitive float.
- * Throws an exception if the type mismatches or if the value is null.
+ * Returns the value at position i as a primitive float. Throws an exception if the type
+ * mismatches or if the value is null.
*
- * @throws ClassCastException when data type does not match.
- * @throws org.apache.spark.SparkRuntimeException when value is null.
+ * @throws ClassCastException
+ * when data type does not match.
+ * @throws org.apache.spark.SparkRuntimeException
+ * when value is null.
*/
def getFloat(i: Int): Float = getAnyValAs[Float](i)
/**
* Returns the value at position i as a primitive double.
*
- * @throws ClassCastException when data type does not match.
- * @throws org.apache.spark.SparkRuntimeException when value is null.
+ * @throws ClassCastException
+ * when data type does not match.
+ * @throws org.apache.spark.SparkRuntimeException
+ * when value is null.
*/
def getDouble(i: Int): Double = getAnyValAs[Double](i)
/**
* Returns the value at position i as a String object.
*
- * @throws ClassCastException when data type does not match.
+ * @throws ClassCastException
+ * when data type does not match.
*/
def getString(i: Int): String = getAs[String](i)
/**
* Returns the value at position i of decimal type as java.math.BigDecimal.
*
- * @throws ClassCastException when data type does not match.
+ * @throws ClassCastException
+ * when data type does not match.
*/
def getDecimal(i: Int): java.math.BigDecimal = getAs[java.math.BigDecimal](i)
/**
* Returns the value at position i of date type as java.sql.Date.
*
- * @throws ClassCastException when data type does not match.
+ * @throws ClassCastException
+ * when data type does not match.
*/
def getDate(i: Int): java.sql.Date = getAs[java.sql.Date](i)
/**
* Returns the value at position i of date type as java.time.LocalDate.
*
- * @throws ClassCastException when data type does not match.
+ * @throws ClassCastException
+ * when data type does not match.
*/
def getLocalDate(i: Int): java.time.LocalDate = getAs[java.time.LocalDate](i)
/**
* Returns the value at position i of date type as java.sql.Timestamp.
*
- * @throws ClassCastException when data type does not match.
+ * @throws ClassCastException
+ * when data type does not match.
*/
def getTimestamp(i: Int): java.sql.Timestamp = getAs[java.sql.Timestamp](i)
/**
* Returns the value at position i of date type as java.time.Instant.
*
- * @throws ClassCastException when data type does not match.
+ * @throws ClassCastException
+ * when data type does not match.
*/
def getInstant(i: Int): java.time.Instant = getAs[java.time.Instant](i)
/**
* Returns the value at position i of array type as a Scala Seq.
*
- * @throws ClassCastException when data type does not match.
+ * @throws ClassCastException
+ * when data type does not match.
*/
def getSeq[T](i: Int): Seq[T] = {
getAs[scala.collection.Seq[T]](i) match {
@@ -334,7 +355,8 @@ trait Row extends Serializable {
/**
* Returns the value at position i of array type as `java.util.List`.
*
- * @throws ClassCastException when data type does not match.
+ * @throws ClassCastException
+ * when data type does not match.
*/
def getList[T](i: Int): java.util.List[T] =
getSeq[T](i).asJava
@@ -342,14 +364,16 @@ trait Row extends Serializable {
/**
* Returns the value at position i of map type as a Scala Map.
*
- * @throws ClassCastException when data type does not match.
+ * @throws ClassCastException
+ * when data type does not match.
*/
def getMap[K, V](i: Int): scala.collection.Map[K, V] = getAs[Map[K, V]](i)
/**
* Returns the value at position i of array type as a `java.util.Map`.
*
- * @throws ClassCastException when data type does not match.
+ * @throws ClassCastException
+ * when data type does not match.
*/
def getJavaMap[K, V](i: Int): java.util.Map[K, V] =
getMap[K, V](i).asJava
@@ -357,48 +381,56 @@ trait Row extends Serializable {
/**
* Returns the value at position i of struct type as a [[Row]] object.
*
- * @throws ClassCastException when data type does not match.
+ * @throws ClassCastException
+ * when data type does not match.
*/
def getStruct(i: Int): Row = getAs[Row](i)
/**
- * Returns the value at position i.
- * For primitive types if value is null it returns 'zero value' specific for primitive
- * i.e. 0 for Int - use isNullAt to ensure that value is not null
+ * Returns the value at position i. For primitive types, if the value is null this returns the
+ * 'zero value' specific to that primitive, i.e. 0 for Int - use isNullAt to ensure that the
+ * value is not null.
*
- * @throws ClassCastException when data type does not match.
+ * @throws ClassCastException
+ * when data type does not match.
*/
def getAs[T](i: Int): T = get(i).asInstanceOf[T]
/**
- * Returns the value of a given fieldName.
- * For primitive types if value is null it returns 'zero value' specific for primitive
- * i.e. 0 for Int - use isNullAt to ensure that value is not null
+ * Returns the value of a given fieldName. For primitive types, if the value is null this
+ * returns the 'zero value' specific to that primitive, i.e. 0 for Int - use isNullAt to ensure
+ * that the value is not null.
*
- * @throws UnsupportedOperationException when schema is not defined.
- * @throws IllegalArgumentException when fieldName do not exist.
- * @throws ClassCastException when data type does not match.
+ * @throws UnsupportedOperationException
+ * when schema is not defined.
+ * @throws IllegalArgumentException
+ * when the fieldName does not exist.
+ * @throws ClassCastException
+ * when data type does not match.
*/
def getAs[T](fieldName: String): T = getAs[T](fieldIndex(fieldName))
/**
* Returns the index of a given field name.
*
- * @throws UnsupportedOperationException when schema is not defined.
- * @throws IllegalArgumentException when a field `name` does not exist.
+ * @throws UnsupportedOperationException
+ * when schema is not defined.
+ * @throws IllegalArgumentException
+ * when a field `name` does not exist.
*/
def fieldIndex(name: String): Int = {
throw DataTypeErrors.fieldIndexOnRowWithoutSchemaError(fieldName = name)
}
/**
- * Returns a Map consisting of names and values for the requested fieldNames
- * For primitive types if value is null it returns 'zero value' specific for primitive
- * i.e. 0 for Int - use isNullAt to ensure that value is not null
+ * Returns a Map consisting of names and values for the requested fieldNames. For primitive
+ * types, if the value is null this returns the 'zero value' specific to that primitive, i.e. 0
+ * for Int - use isNullAt to ensure that the value is not null.
*
- * @throws UnsupportedOperationException when schema is not defined.
- * @throws IllegalArgumentException when fieldName do not exist.
- * @throws ClassCastException when data type does not match.
+ * @throws UnsupportedOperationException
+ * when schema is not defined.
+ * @throws IllegalArgumentException
+ * when the fieldName does not exist.
+ * @throws ClassCastException
+ * when data type does not match.
*/
def getValuesMap[T](fieldNames: Seq[String]): Map[String, T] = {
fieldNames.map { name =>
@@ -445,24 +477,25 @@ trait Row extends Serializable {
o1 match {
case b1: Array[Byte] =>
if (!o2.isInstanceOf[Array[Byte]] ||
- !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
+ !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
return false
}
case f1: Float if java.lang.Float.isNaN(f1) =>
- if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) {
+ if (!o2.isInstanceOf[Float] || !java.lang.Float.isNaN(o2.asInstanceOf[Float])) {
return false
}
case d1: Double if java.lang.Double.isNaN(d1) =>
- if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) {
+ if (!o2.isInstanceOf[Double] || !java.lang.Double.isNaN(o2.asInstanceOf[Double])) {
return false
}
case d1: java.math.BigDecimal if o2.isInstanceOf[java.math.BigDecimal] =>
if (d1.compareTo(o2.asInstanceOf[java.math.BigDecimal]) != 0) {
return false
}
- case _ => if (o1 != o2) {
- return false
- }
+ case _ =>
+ if (o1 != o2) {
+ return false
+ }
}
}
i += 1
@@ -505,8 +538,8 @@ trait Row extends Serializable {
def mkString(sep: String): String = mkString("", sep, "")
/**
- * Displays all elements of this traversable or iterator in a string using
- * start, end, and separator strings.
+ * Displays all elements of this traversable or iterator in a string using start, end, and
+ * separator strings.
*/
def mkString(start: String, sep: String, end: String): String = {
val n = length
@@ -528,9 +561,12 @@ trait Row extends Serializable {
/**
* Returns the value at position i.
*
- * @throws UnsupportedOperationException when schema is not defined.
- * @throws ClassCastException when data type does not match.
- * @throws org.apache.spark.SparkRuntimeException when value is null.
+ * @throws UnsupportedOperationException
+ * when schema is not defined.
+ * @throws ClassCastException
+ * when data type does not match.
+ * @throws org.apache.spark.SparkRuntimeException
+ * when value is null.
*/
private def getAnyValAs[T <: AnyVal](i: Int): T =
if (isNullAt(i)) throw DataTypeErrors.valueIsNullError(i)
@@ -556,7 +592,8 @@ trait Row extends Serializable {
* Note that this only supports the data types that are also supported by
* [[org.apache.spark.sql.catalyst.encoders.RowEncoder]].
*
- * @return the JSON representation of the row.
+ * @return
+ * the JSON representation of the row.
*/
private[sql] def jsonValue: JValue = {
require(schema != null, "JSON serialization requires a non-null schema.")
@@ -598,13 +635,12 @@ trait Row extends Serializable {
case (s: Seq[_], ArrayType(elementType, _)) =>
iteratorToJsonArray(s.iterator, elementType)
case (m: Map[String @unchecked, _], MapType(StringType, valueType, _)) =>
- new JObject(m.toList.sortBy(_._1).map {
- case (k, v) => k -> toJson(v, valueType)
+ new JObject(m.toList.sortBy(_._1).map { case (k, v) =>
+ k -> toJson(v, valueType)
})
case (m: Map[_, _], MapType(keyType, valueType, _)) =>
- new JArray(m.iterator.map {
- case (k, v) =>
- new JObject("key" -> toJson(k, keyType) :: "value" -> toJson(v, valueType) :: Nil)
+ new JArray(m.iterator.map { case (k, v) =>
+ new JObject("key" -> toJson(k, keyType) :: "value" -> toJson(v, valueType) :: Nil)
}.toList)
case (row: Row, schema: StructType) =>
var n = 0
@@ -618,13 +654,13 @@ trait Row extends Serializable {
new JObject(elements.toList)
case (v: Any, udt: UserDefinedType[Any @unchecked]) =>
toJson(UDTUtils.toRow(v, udt), udt.sqlType)
- case _ => throw new SparkIllegalArgumentException(
- errorClass = "FAILED_ROW_TO_JSON",
- messageParameters = Map(
- "value" -> toSQLValue(value.toString),
- "class" -> value.getClass.toString,
- "sqlType" -> toSQLType(dataType.toString))
- )
+ case _ =>
+ throw new SparkIllegalArgumentException(
+ errorClass = "FAILED_ROW_TO_JSON",
+ messageParameters = Map(
+ "value" -> toSQLValue(value.toString),
+ "class" -> value.getClass.toString,
+ "sqlType" -> toSQLType(dataType.toString)))
}
toJson(this, schema)
}
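
The `Row` edits above are formatting and scaladoc reflow only; the accessors they document behave as in this small sketch (the sample values are assumptions):

```scala
import org.apache.spark.sql.Row

object RowAccessExample {
  def main(args: Array[String]): Unit = {
    val row = Row(1, true, "a string", null)

    // Generic access by ordinal boxes primitives and may return null.
    val firstValue: Any = row(0)
    val fourthValue: Any = row(3) // null

    // Native primitive access; guard nullable positions with isNullAt first.
    val firstAsInt: Int = row.getInt(0)
    val third: String = if (row.isNullAt(2)) null else row.getString(2)

    println(s"$firstValue $fourthValue $firstAsInt $third")
  }
}
```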
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala b/sql/api/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala
new file mode 100644
index 0000000000000..23a2774ebc3a5
--- /dev/null
+++ b/sql/api/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import org.apache.spark.annotation.Stable
+
+/**
+ * Runtime configuration interface for Spark. To access this, use `SparkSession.conf`.
+ *
+ * Options set here are automatically propagated to the Hadoop configuration during I/O.
+ *
+ * @since 2.0.0
+ */
+@Stable
+abstract class RuntimeConfig {
+
+ /**
+ * Sets the given Spark runtime configuration property.
+ *
+ * @since 2.0.0
+ */
+ def set(key: String, value: String): Unit
+
+ /**
+ * Sets the given Spark runtime configuration property.
+ *
+ * @since 2.0.0
+ */
+ def set(key: String, value: Boolean): Unit = {
+ set(key, value.toString)
+ }
+
+ /**
+ * Sets the given Spark runtime configuration property.
+ *
+ * @since 2.0.0
+ */
+ def set(key: String, value: Long): Unit = {
+ set(key, value.toString)
+ }
+
+ /**
+ * Returns the value of Spark runtime configuration property for the given key.
+ *
+ * @throws java.util.NoSuchElementException
+ * if the key is not set and does not have a default value
+ * @since 2.0.0
+ */
+ @throws[NoSuchElementException]("if the key is not set")
+ def get(key: String): String
+
+ /**
+ * Returns the value of Spark runtime configuration property for the given key.
+ *
+ * @since 2.0.0
+ */
+ def get(key: String, default: String): String
+
+ /**
+ * Returns all properties set in this conf.
+ *
+ * @since 2.0.0
+ */
+ def getAll: Map[String, String]
+
+ /**
+ * Returns the value of Spark runtime configuration property for the given key.
+ *
+ * @since 2.0.0
+ */
+ def getOption(key: String): Option[String]
+
+ /**
+ * Resets the configuration property for the given key.
+ *
+ * @since 2.0.0
+ */
+ def unset(key: String): Unit
+
+ /**
+ * Indicates whether the configuration property with the given key is modifiable in the current
+ * session.
+ *
+ * @return
+ * `true` if the configuration property is modifiable. For static SQL configurations, Spark
+ * Core configurations, invalid (non-existent) keys and other non-modifiable configuration
+ * properties, the returned value is `false`.
+ * @since 2.4.0
+ */
+ def isModifiable(key: String): Boolean
+}
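
`RuntimeConfig` is introduced here as an abstract interface; the concrete instance is still reached through `SparkSession.conf`. A hedged usage sketch, with the configuration key chosen purely as an example:

```scala
import org.apache.spark.sql.SparkSession

object RuntimeConfigExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("runtime-conf").getOrCreate()
    val conf = spark.conf

    // Set and read back a runtime-modifiable SQL property.
    conf.set("spark.sql.shuffle.partitions", 8L)
    println(conf.get("spark.sql.shuffle.partitions"))          // "8"
    println(conf.getOption("spark.sql.some.unset.key"))        // None
    println(conf.isModifiable("spark.sql.shuffle.partitions")) // true

    // Reset the property to its default.
    conf.unset("spark.sql.shuffle.partitions")

    spark.stop()
  }
}
```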
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/Catalog.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/Catalog.scala
new file mode 100644
index 0000000000000..a0f51d30dc572
--- /dev/null
+++ b/sql/api/src/main/scala/org/apache/spark/sql/api/Catalog.scala
@@ -0,0 +1,682 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api
+
+import scala.jdk.CollectionConverters._
+
+import _root_.java.util
+
+import org.apache.spark.annotation.Stable
+import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.sql.catalog.{CatalogMetadata, Column, Database, Function, Table}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * Catalog interface for Spark. To access this, use `SparkSession.catalog`.
+ *
+ * @since 2.0.0
+ */
+@Stable
+abstract class Catalog {
+
+ /**
+ * Returns the current database (namespace) in this session.
+ *
+ * @since 2.0.0
+ */
+ def currentDatabase: String
+
+ /**
+ * Sets the current database (namespace) in this session.
+ *
+ * @since 2.0.0
+ */
+ def setCurrentDatabase(dbName: String): Unit
+
+ /**
+ * Returns a list of databases (namespaces) available within the current catalog.
+ *
+ * @since 2.0.0
+ */
+ def listDatabases(): Dataset[Database]
+
+ /**
+ * Returns a list of databases (namespaces) whose name matches the specified pattern and that
+ * are available within the current catalog.
+ *
+ * @since 3.5.0
+ */
+ def listDatabases(pattern: String): Dataset[Database]
+
+ /**
+ * Returns a list of tables/views in the current database (namespace). This includes all
+ * temporary views.
+ *
+ * @since 2.0.0
+ */
+ def listTables(): Dataset[Table]
+
+ /**
+ * Returns a list of tables/views in the specified database (namespace) (the name can be
+ * qualified with catalog). This includes all temporary views.
+ *
+ * @since 2.0.0
+ */
+ @throws[AnalysisException]("database does not exist")
+ def listTables(dbName: String): Dataset[Table]
+
+ /**
+ * Returns a list of tables/views in the specified database (namespace) whose name matches the
+ * specified pattern (the name can be qualified with catalog). This includes all temporary views.
+ *
+ * @since 3.5.0
+ */
+ @throws[AnalysisException]("database does not exist")
+ def listTables(dbName: String, pattern: String): Dataset[Table]
+
+ /**
+ * Returns a list of functions registered in the current database (namespace). This includes all
+ * temporary functions.
+ *
+ * @since 2.0.0
+ */
+ def listFunctions(): Dataset[Function]
+
+ /**
+ * Returns a list of functions registered in the specified database (namespace) (the name can be
+ * qualified with catalog). This includes all built-in and temporary functions.
+ *
+ * @since 2.0.0
+ */
+ @throws[AnalysisException]("database does not exist")
+ def listFunctions(dbName: String): Dataset[Function]
+
+ /**
+ * Returns a list of functions registered in the specified database (namespace) whose name
+ * matches the specified pattern (the name can be qualified with catalog). This includes all
+ * built-in and temporary functions.
+ *
+ * @since 3.5.0
+ */
+ @throws[AnalysisException]("database does not exist")
+ def listFunctions(dbName: String, pattern: String): Dataset[Function]
+
+ /**
+ * Returns a list of columns for the given table/view or temporary view.
+ *
+ * @param tableName
+ * is either a qualified or unqualified name that designates a table/view. It follows the same
+ * resolution rule with SQL: search for temp views first then table/views in the current
+ * database (namespace).
+ * @since 2.0.0
+ */
+ @throws[AnalysisException]("table does not exist")
+ def listColumns(tableName: String): Dataset[Column]
+
+ /**
+ * Returns a list of columns for the given table/view in the specified database under the Hive
+ * Metastore.
+ *
+ * To list columns for table/view in other catalogs, please use `listColumns(tableName)` with
+ * qualified table/view name instead.
+ *
+ * @param dbName
+ * is an unqualified name that designates a database.
+ * @param tableName
+ * is an unqualified name that designates a table/view.
+ * @since 2.0.0
+ */
+ @throws[AnalysisException]("database or table does not exist")
+ def listColumns(dbName: String, tableName: String): Dataset[Column]
+
+ /**
+ * Get the database (namespace) with the specified name (can be qualified with catalog). This
+ * throws an AnalysisException when the database (namespace) cannot be found.
+ *
+ * @since 2.1.0
+ */
+ @throws[AnalysisException]("database does not exist")
+ def getDatabase(dbName: String): Database
+
+ /**
+ * Get the table or view with the specified name. This table can be a temporary view or a
+ * table/view. This throws an AnalysisException when no Table can be found.
+ *
+ * @param tableName
+ * is either a qualified or unqualified name that designates a table/view. It follows the same
+ * resolution rule with SQL: search for temp views first then table/views in the current
+ * database (namespace).
+ * @since 2.1.0
+ */
+ @throws[AnalysisException]("table does not exist")
+ def getTable(tableName: String): Table
+
+ /**
+ * Get the table or view with the specified name in the specified database under the Hive
+ * Metastore. This throws an AnalysisException when no Table can be found.
+ *
+ * To get table/view in other catalogs, please use `getTable(tableName)` with qualified
+ * table/view name instead.
+ *
+ * @since 2.1.0
+ */
+ @throws[AnalysisException]("database or table does not exist")
+ def getTable(dbName: String, tableName: String): Table
+
+ /**
+ * Get the function with the specified name. This function can be a temporary function or a
+ * permanent function. This throws an AnalysisException when the function cannot be found.
+ *
+ * @param functionName
+ * is either a qualified or unqualified name that designates a function. It follows the same
+ * resolution rule with SQL: search for built-in/temp functions first then functions in the
+ * current database (namespace).
+ * @since 2.1.0
+ */
+ @throws[AnalysisException]("function does not exist")
+ def getFunction(functionName: String): Function
+
+ /**
+ * Get the function with the specified name in the specified database under the Hive Metastore.
+ * This throws an AnalysisException when the function cannot be found.
+ *
+ * To get functions in other catalogs, please use `getFunction(functionName)` with qualified
+ * function name instead.
+ *
+ * @param dbName
+ * is an unqualified name that designates a database.
+ * @param functionName
+ * is an unqualified name that designates a function in the specified database
+ * @since 2.1.0
+ */
+ @throws[AnalysisException]("database or function does not exist")
+ def getFunction(dbName: String, functionName: String): Function
+
+ /**
+ * Check if the database (namespace) with the specified name exists (the name can be qualified
+ * with catalog).
+ *
+ * @since 2.1.0
+ */
+ def databaseExists(dbName: String): Boolean
+
+ /**
+ * Check if the table or view with the specified name exists. This can either be a temporary
+ * view or a table/view.
+ *
+ * @param tableName
+ * is either a qualified or unqualified name that designates a table/view. It follows the same
+ * resolution rule with SQL: search for temp views first then table/views in the current
+ * database (namespace).
+ * @since 2.1.0
+ */
+ def tableExists(tableName: String): Boolean
+
+ /**
+ * Check if the table or view with the specified name exists in the specified database under the
+ * Hive Metastore.
+ *
+ * To check existence of table/view in other catalogs, please use `tableExists(tableName)` with
+ * qualified table/view name instead.
+ *
+ * @param dbName
+ * is an unqualified name that designates a database.
+ * @param tableName
+ * is an unqualified name that designates a table.
+ * @since 2.1.0
+ */
+ def tableExists(dbName: String, tableName: String): Boolean
+
+ /**
+ * Check if the function with the specified name exists. This can either be a temporary function
+ * or a function.
+ *
+ * @param functionName
+ * is either a qualified or unqualified name that designates a function. It follows the same
+ * resolution rule with SQL: search for built-in/temp functions first then functions in the
+ * current database (namespace).
+ * @since 2.1.0
+ */
+ def functionExists(functionName: String): Boolean
+
+ /**
+ * Check if the function with the specified name exists in the specified database under the Hive
+ * Metastore.
+ *
+ * To check existence of functions in other catalogs, please use `functionExists(functionName)`
+ * with qualified function name instead.
+ *
+ * @param dbName
+ * is an unqualified name that designates a database.
+ * @param functionName
+ * is an unqualified name that designates a function.
+ * @since 2.1.0
+ */
+ def functionExists(dbName: String, functionName: String): Boolean
+
+ /**
+ * Creates a table from the given path and returns the corresponding DataFrame. It will use the
+ * default data source configured by spark.sql.sources.default.
+ *
+ * @param tableName
+ * is either a qualified or unqualified name that designates a table. If no database
+ * identifier is provided, it refers to a table in the current database.
+ * @since 2.0.0
+ */
+ @deprecated("use createTable instead.", "2.2.0")
+ def createExternalTable(tableName: String, path: String): Dataset[Row] = {
+ createTable(tableName, path)
+ }
+
+ /**
+ * Creates a table from the given path and returns the corresponding DataFrame. It will use the
+ * default data source configured by spark.sql.sources.default.
+ *
+ * @param tableName
+ * is either a qualified or unqualified name that designates a table. If no database
+ * identifier is provided, it refers to a table in the current database.
+ * @since 2.2.0
+ */
+ def createTable(tableName: String, path: String): Dataset[Row]
+
+ /**
+ * Creates a table from the given path based on a data source and returns the corresponding
+ * DataFrame.
+ *
+ * @param tableName
+ * is either a qualified or unqualified name that designates a table. If no database
+ * identifier is provided, it refers to a table in the current database.
+ * @since 2.0.0
+ */
+ @deprecated("use createTable instead.", "2.2.0")
+ def createExternalTable(tableName: String, path: String, source: String): Dataset[Row] = {
+ createTable(tableName, path, source)
+ }
+
+ /**
+ * Creates a table from the given path based on a data source and returns the corresponding
+ * DataFrame.
+ *
+ * @param tableName
+ * is either a qualified or unqualified name that designates a table. If no database
+ * identifier is provided, it refers to a table in the current database.
+ * @since 2.2.0
+ */
+ def createTable(tableName: String, path: String, source: String): Dataset[Row]
+
+ /**
+ * Creates a table from the given path based on a data source and a set of options. Then,
+ * returns the corresponding DataFrame.
+ *
+ * @param tableName
+ * is either a qualified or unqualified name that designates a table. If no database
+ * identifier is provided, it refers to a table in the current database.
+ * @since 2.0.0
+ */
+ @deprecated("use createTable instead.", "2.2.0")
+ def createExternalTable(
+ tableName: String,
+ source: String,
+ options: util.Map[String, String]): Dataset[Row] = {
+ createTable(tableName, source, options)
+ }
+
+ /**
+ * Creates a table based on the dataset in a data source and a set of options. Then, returns the
+ * corresponding DataFrame.
+ *
+ * @param tableName
+ * is either a qualified or unqualified name that designates a table. If no database
+ * identifier is provided, it refers to a table in the current database.
+ * @since 2.2.0
+ */
+ def createTable(
+ tableName: String,
+ source: String,
+ options: util.Map[String, String]): Dataset[Row] = {
+ createTable(tableName, source, options.asScala.toMap)
+ }
+
+ /**
+ * (Scala-specific) Creates a table from the given path based on a data source and a set of
+ * options. Then, returns the corresponding DataFrame.
+ *
+ * @param tableName
+ * is either a qualified or unqualified name that designates a table. If no database
+ * identifier is provided, it refers to a table in the current database.
+ * @since 2.0.0
+ */
+ @deprecated("use createTable instead.", "2.2.0")
+ def createExternalTable(
+ tableName: String,
+ source: String,
+ options: Map[String, String]): Dataset[Row] = {
+ createTable(tableName, source, options)
+ }
+
+ /**
+ * (Scala-specific) Creates a table based on the dataset in a data source and a set of options.
+ * Then, returns the corresponding DataFrame.
+ *
+ * @param tableName
+ * is either a qualified or unqualified name that designates a table. If no database
+ * identifier is provided, it refers to a table in the current database.
+ * @since 2.2.0
+ */
+ def createTable(tableName: String, source: String, options: Map[String, String]): Dataset[Row]
+
+ /**
+ * Create a table from the given path based on a data source, a schema and a set of options.
+ * Then, returns the corresponding DataFrame.
+ *
+ * @param tableName
+ * is either a qualified or unqualified name that designates a table. If no database
+ * identifier is provided, it refers to a table in the current database.
+ * @since 2.0.0
+ */
+ @deprecated("use createTable instead.", "2.2.0")
+ def createExternalTable(
+ tableName: String,
+ source: String,
+ schema: StructType,
+ options: util.Map[String, String]): Dataset[Row] = {
+ createTable(tableName, source, schema, options)
+ }
+
+ /**
+ * Creates a table based on the dataset in a data source and a set of options. Then, returns the
+ * corresponding DataFrame.
+ *
+ * @param tableName
+ * is either a qualified or unqualified name that designates a table. If no database
+ * identifier is provided, it refers to a table in the current database.
+ * @since 3.1.0
+ */
+ def createTable(
+ tableName: String,
+ source: String,
+ description: String,
+ options: util.Map[String, String]): Dataset[Row] = {
+ createTable(
+ tableName,
+ source = source,
+ description = description,
+ options = options.asScala.toMap)
+ }
+
+ /**
+ * (Scala-specific) Creates a table based on the dataset in a data source and a set of options.
+ * Then, returns the corresponding DataFrame.
+ *
+ * @param tableName
+ * is either a qualified or unqualified name that designates a table. If no database
+ * identifier is provided, it refers to a table in the current database.
+ * @since 3.1.0
+ */
+ def createTable(
+ tableName: String,
+ source: String,
+ description: String,
+ options: Map[String, String]): Dataset[Row]
+
+ /**
+ * Create a table based on the dataset in a data source, a schema and a set of options. Then,
+ * returns the corresponding DataFrame.
+ *
+ * @param tableName
+ * is either a qualified or unqualified name that designates a table. If no database
+ * identifier is provided, it refers to a table in the current database.
+ * @since 2.2.0
+ */
+ def createTable(
+ tableName: String,
+ source: String,
+ schema: StructType,
+ options: util.Map[String, String]): Dataset[Row] = {
+ createTable(tableName, source, schema, options.asScala.toMap)
+ }
+
+ /**
+ * (Scala-specific) Create a table from the given path based on a data source, a schema and a
+ * set of options. Then, returns the corresponding DataFrame.
+ *
+ * @param tableName
+ * is either a qualified or unqualified name that designates a table. If no database
+ * identifier is provided, it refers to a table in the current database.
+ * @since 2.0.0
+ */
+ @deprecated("use createTable instead.", "2.2.0")
+ def createExternalTable(
+ tableName: String,
+ source: String,
+ schema: StructType,
+ options: Map[String, String]): Dataset[Row] = {
+ createTable(tableName, source, schema, options)
+ }
+
+ /**
+ * (Scala-specific) Create a table based on the dataset in a data source, a schema and a set of
+ * options. Then, returns the corresponding DataFrame.
+ *
+ * @param tableName
+ * is either a qualified or unqualified name that designates a table. If no database
+ * identifier is provided, it refers to a table in the current database.
+ * @since 2.2.0
+ */
+ def createTable(
+ tableName: String,
+ source: String,
+ schema: StructType,
+ options: Map[String, String]): Dataset[Row]
+
+ /**
+ * Create a table based on the dataset in a data source, a schema and a set of options. Then,
+ * returns the corresponding DataFrame.
+ *
+ * @param tableName
+ * is either a qualified or unqualified name that designates a table. If no database
+ * identifier is provided, it refers to a table in the current database.
+ * @since 3.1.0
+ */
+ def createTable(
+ tableName: String,
+ source: String,
+ schema: StructType,
+ description: String,
+ options: util.Map[String, String]): Dataset[Row] = {
+ createTable(
+ tableName,
+ source = source,
+ schema = schema,
+ description = description,
+ options = options.asScala.toMap)
+ }
+
+ /**
+ * (Scala-specific) Create a table based on the dataset in a data source, a schema and a set of
+ * options. Then, returns the corresponding DataFrame.
+ *
+ * @param tableName
+ * is either a qualified or unqualified name that designates a table. If no database
+ * identifier is provided, it refers to a table in the current database.
+ * @since 3.1.0
+ */
+ def createTable(
+ tableName: String,
+ source: String,
+ schema: StructType,
+ description: String,
+ options: Map[String, String]): Dataset[Row]
+
+ /**
+ * Drops the local temporary view with the given view name in the catalog. If the view has been
+ * cached before, then it will also be uncached.
+ *
+ * Local temporary view is session-scoped. Its lifetime is the lifetime of the session that
+ * created it, i.e. it will be automatically dropped when the session terminates. It's not tied
+ * to any databases, i.e. we can't use `db1.view1` to reference a local temporary view.
+ *
+ * Note that the return type of this method was Unit in Spark 2.0, but changed to Boolean in
+ * Spark 2.1.
+ *
+ * @param viewName
+ * the name of the temporary view to be dropped.
+ * @return
+ * true if the view is dropped successfully, false otherwise.
+ * @since 2.0.0
+ */
+ def dropTempView(viewName: String): Boolean
+
+ /**
+ * Drops the global temporary view with the given view name in the catalog. If the view has been
+ * cached before, then it will also be uncached.
+ *
+ * Global temporary view is cross-session. Its lifetime is the lifetime of the Spark
+ * application, i.e. it will be automatically dropped when the application terminates. It's tied
+ * to a system preserved database `global_temp`, and we must use the qualified name to refer to
+ * a global temp view, e.g. `SELECT * FROM global_temp.view1`.
+ *
+ * @param viewName
+ * the unqualified name of the temporary view to be dropped.
+ * @return
+ * true if the view is dropped successfully, false otherwise.
+ * @since 2.1.0
+ */
+ def dropGlobalTempView(viewName: String): Boolean
+
+ /**
+ * Recovers all the partitions in the directory of a table and updates the catalog. Only works
+ * with a partitioned table, and not a view.
+ *
+ * @param tableName
+ * is either a qualified or unqualified name that designates a table. If no database
+ * identifier is provided, it refers to a table in the current database.
+ * @since 2.1.1
+ */
+ def recoverPartitions(tableName: String): Unit
+
+ /**
+ * Returns true if the table is currently cached in-memory.
+ *
+ * @param tableName
+ * is either a qualified or unqualified name that designates a table/view. If no database
+ * identifier is provided, it refers to a temporary view or a table/view in the current
+ * database.
+ * @since 2.0.0
+ */
+ def isCached(tableName: String): Boolean
+
+ /**
+ * Caches the specified table in-memory.
+ *
+ * @param tableName
+ * is either a qualified or unqualified name that designates a table/view. If no database
+ * identifier is provided, it refers to a temporary view or a table/view in the current
+ * database.
+ * @since 2.0.0
+ */
+ def cacheTable(tableName: String): Unit
+
+ /**
+ * Caches the specified table with the given storage level.
+ *
+ * @param tableName
+ * is either a qualified or unqualified name that designates a table/view. If no database
+ * identifier is provided, it refers to a temporary view or a table/view in the current
+ * database.
+ * @param storageLevel
+ * storage level to cache table.
+ * @since 2.3.0
+ */
+ def cacheTable(tableName: String, storageLevel: StorageLevel): Unit
+
+ /**
+ * Removes the specified table from the in-memory cache.
+ *
+ * @param tableName
+ * is either a qualified or unqualified name that designates a table/view. If no database
+ * identifier is provided, it refers to a temporary view or a table/view in the current
+ * database.
+ * @since 2.0.0
+ */
+ def uncacheTable(tableName: String): Unit
+
+ /**
+ * Removes all cached tables from the in-memory cache.
+ *
+ * @since 2.0.0
+ */
+ def clearCache(): Unit
+
+ /**
+ * Invalidates and refreshes all the cached data and metadata of the given table. For
+ * performance reasons, Spark SQL or the external data source library it uses might cache
+ * certain metadata about a table, such as the location of blocks. When those change outside of
+ * Spark SQL, users should call this function to invalidate the cache.
+ *
+ * If this table is cached as an InMemoryRelation, drop the original cached version and make the
+ * new version cached lazily.
+ *
+ * @param tableName
+ * is either a qualified or unqualified name that designates a table/view. If no database
+ * identifier is provided, it refers to a temporary view or a table/view in the current
+ * database.
+ * @since 2.0.0
+ */
+ def refreshTable(tableName: String): Unit
+
+ /**
+ * Invalidates and refreshes all the cached data (and the associated metadata) for any `Dataset`
+ * that contains the given data source path. Path matching is by checking for sub-directories,
+ * i.e. "/" would invalidate everything that is cached and "/test/parent" would invalidate
+ * everything that is a subdirectory of "/test/parent".
+ *
+ * @since 2.0.0
+ */
+ def refreshByPath(path: String): Unit
+
+ /**
+ * Returns the current catalog in this session.
+ *
+ * @since 3.4.0
+ */
+ def currentCatalog(): String
+
+ /**
+ * Sets the current catalog in this session.
+ *
+ * @since 3.4.0
+ */
+ def setCurrentCatalog(catalogName: String): Unit
+
+ /**
+ * Returns a list of catalogs available in this session.
+ *
+ * @since 3.4.0
+ */
+ def listCatalogs(): Dataset[CatalogMetadata]
+
+ /**
+ * Returns a list of catalogs whose name matches the specified pattern and that are available
+ * in this session.
+ *
+ * @since 3.5.0
+ */
+ def listCatalogs(pattern: String): Dataset[CatalogMetadata]
+}
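
`Catalog` likewise becomes an abstract interface obtained via `SparkSession.catalog`. A brief sketch of the listing, existence-check and caching calls documented above; the temp view name and sample data are assumptions:

```scala
import org.apache.spark.sql.SparkSession

object CatalogExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("catalog").getOrCreate()
    import spark.implicits._

    // Register a session-scoped temporary view to have something to list.
    Seq((1, "a"), (2, "b")).toDF("id", "value").createOrReplaceTempView("people")

    val catalog = spark.catalog
    println(catalog.currentDatabase)
    catalog.listTables().show()

    if (catalog.tableExists("people")) {
      catalog.cacheTable("people")
      println(catalog.isCached("people")) // true
      catalog.uncacheTable("people")
    }

    spark.stop()
  }
}
```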
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameNaFunctions.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameNaFunctions.scala
index 7400f90992d8f..ef6cc64c058a4 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameNaFunctions.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameNaFunctions.scala
@@ -30,32 +30,32 @@ import org.apache.spark.util.ArrayImplicits._
* @since 1.3.1
*/
@Stable
-abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] {
+abstract class DataFrameNaFunctions {
/**
* Returns a new `DataFrame` that drops rows containing any null or NaN values.
*
* @since 1.3.1
*/
- def drop(): DS[Row] = drop("any")
+ def drop(): Dataset[Row] = drop("any")
/**
* Returns a new `DataFrame` that drops rows containing null or NaN values.
*
- * If `how` is "any", then drop rows containing any null or NaN values.
- * If `how` is "all", then drop rows only if every column is null or NaN for that row.
+ * If `how` is "any", then drop rows containing any null or NaN values. If `how` is "all", then
+ * drop rows only if every column is null or NaN for that row.
*
* @since 1.3.1
*/
- def drop(how: String): DS[Row] = drop(toMinNonNulls(how))
+ def drop(how: String): Dataset[Row] = drop(toMinNonNulls(how))
/**
- * Returns a new `DataFrame` that drops rows containing any null or NaN values
- * in the specified columns.
+ * Returns a new `DataFrame` that drops rows containing any null or NaN values in the specified
+ * columns.
*
* @since 1.3.1
*/
- def drop(cols: Array[String]): DS[Row] = drop(cols.toImmutableArraySeq)
+ def drop(cols: Array[String]): Dataset[Row] = drop(cols.toImmutableArraySeq)
/**
* (Scala-specific) Returns a new `DataFrame` that drops rows containing any null or NaN values
@@ -63,54 +63,54 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] {
*
* @since 1.3.1
*/
- def drop(cols: Seq[String]): DS[Row] = drop(cols.size, cols)
+ def drop(cols: Seq[String]): Dataset[Row] = drop(cols.size, cols)
/**
- * Returns a new `DataFrame` that drops rows containing null or NaN values
- * in the specified columns.
+ * Returns a new `DataFrame` that drops rows containing null or NaN values in the specified
+ * columns.
*
* If `how` is "any", then drop rows containing any null or NaN values in the specified columns.
* If `how` is "all", then drop rows only if every specified column is null or NaN for that row.
*
* @since 1.3.1
*/
- def drop(how: String, cols: Array[String]): DS[Row] = drop(how, cols.toImmutableArraySeq)
+ def drop(how: String, cols: Array[String]): Dataset[Row] = drop(how, cols.toImmutableArraySeq)
/**
- * (Scala-specific) Returns a new `DataFrame` that drops rows containing null or NaN values
- * in the specified columns.
+ * (Scala-specific) Returns a new `DataFrame` that drops rows containing null or NaN values in
+ * the specified columns.
*
* If `how` is "any", then drop rows containing any null or NaN values in the specified columns.
* If `how` is "all", then drop rows only if every specified column is null or NaN for that row.
*
* @since 1.3.1
*/
- def drop(how: String, cols: Seq[String]): DS[Row] = drop(toMinNonNulls(how), cols)
+ def drop(how: String, cols: Seq[String]): Dataset[Row] = drop(toMinNonNulls(how), cols)
/**
- * Returns a new `DataFrame` that drops rows containing
- * less than `minNonNulls` non-null and non-NaN values.
+ * Returns a new `DataFrame` that drops rows containing less than `minNonNulls` non-null and
+ * non-NaN values.
*
* @since 1.3.1
*/
- def drop(minNonNulls: Int): DS[Row] = drop(Option(minNonNulls))
+ def drop(minNonNulls: Int): Dataset[Row] = drop(Option(minNonNulls))
/**
- * Returns a new `DataFrame` that drops rows containing
- * less than `minNonNulls` non-null and non-NaN values in the specified columns.
+ * Returns a new `DataFrame` that drops rows containing less than `minNonNulls` non-null and
+ * non-NaN values in the specified columns.
*
* @since 1.3.1
*/
- def drop(minNonNulls: Int, cols: Array[String]): DS[Row] =
+ def drop(minNonNulls: Int, cols: Array[String]): Dataset[Row] =
drop(minNonNulls, cols.toImmutableArraySeq)
/**
- * (Scala-specific) Returns a new `DataFrame` that drops rows containing less than
- * `minNonNulls` non-null and non-NaN values in the specified columns.
+ * (Scala-specific) Returns a new `DataFrame` that drops rows containing less than `minNonNulls`
+ * non-null and non-NaN values in the specified columns.
*
* @since 1.3.1
*/
- def drop(minNonNulls: Int, cols: Seq[String]): DS[Row] = drop(Option(minNonNulls), cols)
+ def drop(minNonNulls: Int, cols: Seq[String]): Dataset[Row] = drop(Option(minNonNulls), cols)
private def toMinNonNulls(how: String): Option[Int] = {
how.toLowerCase(util.Locale.ROOT) match {
@@ -120,45 +120,46 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] {
}
}
- protected def drop(minNonNulls: Option[Int]): DS[Row]
+ protected def drop(minNonNulls: Option[Int]): Dataset[Row]
- protected def drop(minNonNulls: Option[Int], cols: Seq[String]): DS[Row]
+ protected def drop(minNonNulls: Option[Int], cols: Seq[String]): Dataset[Row]
/**
* Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`.
*
* @since 2.2.0
*/
- def fill(value: Long): DS[Row]
+ def fill(value: Long): Dataset[Row]
/**
* Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`.
* @since 1.3.1
*/
- def fill(value: Double): DS[Row]
+ def fill(value: Double): Dataset[Row]
/**
* Returns a new `DataFrame` that replaces null values in string columns with `value`.
*
* @since 1.3.1
*/
- def fill(value: String): DS[Row]
+ def fill(value: String): Dataset[Row]
/**
- * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns.
- * If a specified column is not a numeric column, it is ignored.
+ * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns. If a
+ * specified column is not a numeric column, it is ignored.
*
* @since 2.2.0
*/
- def fill(value: Long, cols: Array[String]): DS[Row] = fill(value, cols.toImmutableArraySeq)
+ def fill(value: Long, cols: Array[String]): Dataset[Row] = fill(value, cols.toImmutableArraySeq)
/**
- * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns.
- * If a specified column is not a numeric column, it is ignored.
+ * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns. If a
+ * specified column is not a numeric column, it is ignored.
*
* @since 1.3.1
*/
- def fill(value: Double, cols: Array[String]): DS[Row] = fill(value, cols.toImmutableArraySeq)
+ def fill(value: Double, cols: Array[String]): Dataset[Row] =
+ fill(value, cols.toImmutableArraySeq)
/**
* (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified
@@ -166,7 +167,7 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] {
*
* @since 2.2.0
*/
- def fill(value: Long, cols: Seq[String]): DS[Row]
+ def fill(value: Long, cols: Seq[String]): Dataset[Row]
/**
* (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified
@@ -174,58 +175,58 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] {
*
* @since 1.3.1
*/
- def fill(value: Double, cols: Seq[String]): DS[Row]
-
+ def fill(value: Double, cols: Seq[String]): Dataset[Row]
/**
- * Returns a new `DataFrame` that replaces null values in specified string columns.
- * If a specified column is not a string column, it is ignored.
+ * Returns a new `DataFrame` that replaces null values in specified string columns. If a
+ * specified column is not a string column, it is ignored.
*
* @since 1.3.1
*/
- def fill(value: String, cols: Array[String]): DS[Row] = fill(value, cols.toImmutableArraySeq)
+ def fill(value: String, cols: Array[String]): Dataset[Row] =
+ fill(value, cols.toImmutableArraySeq)
/**
- * (Scala-specific) Returns a new `DataFrame` that replaces null values in
- * specified string columns. If a specified column is not a string column, it is ignored.
+ * (Scala-specific) Returns a new `DataFrame` that replaces null values in specified string
+ * columns. If a specified column is not a string column, it is ignored.
*
* @since 1.3.1
*/
- def fill(value: String, cols: Seq[String]): DS[Row]
+ def fill(value: String, cols: Seq[String]): Dataset[Row]
/**
* Returns a new `DataFrame` that replaces null values in boolean columns with `value`.
*
* @since 2.3.0
*/
- def fill(value: Boolean): DS[Row]
+ def fill(value: Boolean): Dataset[Row]
/**
- * (Scala-specific) Returns a new `DataFrame` that replaces null values in specified
- * boolean columns. If a specified column is not a boolean column, it is ignored.
+ * (Scala-specific) Returns a new `DataFrame` that replaces null values in specified boolean
+ * columns. If a specified column is not a boolean column, it is ignored.
*
* @since 2.3.0
*/
- def fill(value: Boolean, cols: Seq[String]): DS[Row]
+ def fill(value: Boolean, cols: Seq[String]): Dataset[Row]
/**
- * Returns a new `DataFrame` that replaces null values in specified boolean columns.
- * If a specified column is not a boolean column, it is ignored.
+ * Returns a new `DataFrame` that replaces null values in specified boolean columns. If a
+ * specified column is not a boolean column, it is ignored.
*
* @since 2.3.0
*/
- def fill(value: Boolean, cols: Array[String]): DS[Row] = fill(value, cols.toImmutableArraySeq)
+ def fill(value: Boolean, cols: Array[String]): Dataset[Row] =
+ fill(value, cols.toImmutableArraySeq)
/**
* Returns a new `DataFrame` that replaces null values.
*
- * The key of the map is the column name, and the value of the map is the replacement value.
- * The value must be of the following type:
- * `Integer`, `Long`, `Float`, `Double`, `String`, `Boolean`.
- * Replacement values are cast to the column data type.
+ * The key of the map is the column name, and the value of the map is the replacement value. The
+ * value must be of the following type: `Integer`, `Long`, `Float`, `Double`, `String`,
+ * `Boolean`. Replacement values are cast to the column data type.
*
- * For example, the following replaces null values in column "A" with string "unknown", and
- * null values in column "B" with numeric value 1.0.
+ * For example, the following replaces null values in column "A" with string "unknown", and null
+ * values in column "B" with numeric value 1.0.
* {{{
* import com.google.common.collect.ImmutableMap;
* df.na.fill(ImmutableMap.of("A", "unknown", "B", 1.0));
@@ -233,17 +234,17 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] {
*
* @since 1.3.1
*/
- def fill(valueMap: util.Map[String, Any]): DS[Row] = fillMap(valueMap.asScala.toSeq)
+ def fill(valueMap: util.Map[String, Any]): Dataset[Row] = fillMap(valueMap.asScala.toSeq)
/**
* (Scala-specific) Returns a new `DataFrame` that replaces null values.
*
- * The key of the map is the column name, and the value of the map is the replacement value.
- * The value must be of the following type: `Int`, `Long`, `Float`, `Double`, `String`, `Boolean`.
+ * The key of the map is the column name, and the value of the map is the replacement value. The
+ * value must be of the following type: `Int`, `Long`, `Float`, `Double`, `String`, `Boolean`.
* Replacement values are cast to the column data type.
*
- * For example, the following replaces null values in column "A" with string "unknown", and
- * null values in column "B" with numeric value 1.0.
+ * For example, the following replaces null values in column "A" with string "unknown", and null
+ * values in column "B" with numeric value 1.0.
* {{{
* df.na.fill(Map(
* "A" -> "unknown",
@@ -253,9 +254,9 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] {
*
* @since 1.3.1
*/
- def fill(valueMap: Map[String, Any]): DS[Row] = fillMap(valueMap.toSeq)
+ def fill(valueMap: Map[String, Any]): Dataset[Row] = fillMap(valueMap.toSeq)
- protected def fillMap(values: Seq[(String, Any)]): DS[Row]
+ protected def fillMap(values: Seq[(String, Any)]): Dataset[Row]
/**
* Replaces values matching keys in `replacement` map with the corresponding values.
@@ -273,15 +274,16 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] {
* df.na.replace("*", ImmutableMap.of("UNKNOWN", "unnamed"));
* }}}
*
- * @param col name of the column to apply the value replacement. If `col` is "*",
- * replacement is applied on all string, numeric or boolean columns.
- * @param replacement value replacement map. Key and value of `replacement` map must have
- * the same type, and can only be doubles, strings or booleans.
- * The map value can have nulls.
+ * @param col
+ * name of the column to apply the value replacement. If `col` is "*", replacement is applied
+ * on all string, numeric or boolean columns.
+ * @param replacement
+ * value replacement map. Key and value of `replacement` map must have the same type, and can
+ * only be doubles, strings or booleans. The map value can have nulls.
*
* @since 1.3.1
*/
- def replace[T](col: String, replacement: util.Map[T, T]): DS[Row] = {
+ def replace[T](col: String, replacement: util.Map[T, T]): Dataset[Row] = {
replace[T](col, replacement.asScala.toMap)
}
@@ -298,15 +300,16 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] {
* df.na.replace(new String[] {"firstname", "lastname"}, ImmutableMap.of("UNKNOWN", "unnamed"));
* }}}
*
- * @param cols list of columns to apply the value replacement. If `col` is "*",
- * replacement is applied on all string, numeric or boolean columns.
- * @param replacement value replacement map. Key and value of `replacement` map must have
- * the same type, and can only be doubles, strings or booleans.
- * The map value can have nulls.
+ * @param cols
+ * list of columns to apply the value replacement. If `col` is "*", replacement is applied on
+ * all string, numeric or boolean columns.
+ * @param replacement
+ * value replacement map. Key and value of `replacement` map must have the same type, and can
+ * only be doubles, strings or booleans. The map value can have nulls.
*
* @since 1.3.1
*/
- def replace[T](cols: Array[String], replacement: util.Map[T, T]): DS[Row] = {
+ def replace[T](cols: Array[String], replacement: util.Map[T, T]): Dataset[Row] = {
replace(cols.toImmutableArraySeq, replacement.asScala.toMap)
}
@@ -324,15 +327,16 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] {
* df.na.replace("*", Map("UNKNOWN" -> "unnamed"));
* }}}
*
- * @param col name of the column to apply the value replacement. If `col` is "*",
- * replacement is applied on all string, numeric or boolean columns.
- * @param replacement value replacement map. Key and value of `replacement` map must have
- * the same type, and can only be doubles, strings or booleans.
- * The map value can have nulls.
+ * @param col
+ * name of the column to apply the value replacement. If `col` is "*", replacement is applied
+ * on all string, numeric or boolean columns.
+ * @param replacement
+ * value replacement map. Key and value of `replacement` map must have the same type, and can
+ * only be doubles, strings or booleans. The map value can have nulls.
*
* @since 1.3.1
*/
- def replace[T](col: String, replacement: Map[T, T]): DS[Row]
+ def replace[T](col: String, replacement: Map[T, T]): Dataset[Row]
/**
* (Scala-specific) Replaces values matching keys in `replacement` map.
@@ -345,13 +349,14 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] {
* df.na.replace("firstname" :: "lastname" :: Nil, Map("UNKNOWN" -> "unnamed"));
* }}}
*
- * @param cols list of columns to apply the value replacement. If `col` is "*",
- * replacement is applied on all string, numeric or boolean columns.
- * @param replacement value replacement map. Key and value of `replacement` map must have
- * the same type, and can only be doubles, strings or booleans.
- * The map value can have nulls.
+ * @param cols
+ * list of columns to apply the value replacement. If `col` is "*", replacement is applied on
+ * all string, numeric or boolean columns.
+ * @param replacement
+ * value replacement map. Key and value of `replacement` map must have the same type, and can
+ * only be doubles, strings or booleans. The map value can have nulls.
*
* @since 1.3.1
*/
- def replace[T](cols: Seq[String], replacement: Map[T, T]): DS[Row]
+ def replace[T](cols: Seq[String], replacement: Map[T, T]): Dataset[Row]
}
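
As a quick orientation to the `DataFrameNaFunctions` surface being refactored above, the sketch below shows typical calls through `Dataset.na`. It assumes a `SparkSession` named `spark`; the sample data and column names are hypothetical.

    import spark.implicits._

    val df = Seq(
      (Some(1), Some("a")),
      (None,    Some("b")),
      (Some(3), None)
    ).toDF("id", "name")

    // Drop rows containing any null or NaN value.
    val cleaned = df.na.drop()

    // Keep rows with at least one non-null value in the listed columns.
    val atLeastOne = df.na.drop(1, Seq("id", "name"))

    // Fill nulls per column; replacement values are cast to the column type.
    val filled = df.na.fill(Map("id" -> 0, "name" -> "unknown"))

    // Replace specific values in a string column.
    val renamed = df.na.replace("name", Map("b" -> "beta"))
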
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameReader.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameReader.scala
new file mode 100644
index 0000000000000..c101c52fd0662
--- /dev/null
+++ b/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameReader.scala
@@ -0,0 +1,571 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.api
+
+import scala.jdk.CollectionConverters._
+
+import _root_.java.util
+
+import org.apache.spark.annotation.Stable
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, SparkCharVarcharUtils}
+import org.apache.spark.sql.errors.DataTypeErrors
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Interface used to load a [[Dataset]] from external storage systems (e.g. file systems,
+ * key-value stores, etc). Use `SparkSession.read` to access this.
+ *
+ * @since 1.4.0
+ */
+@Stable
+abstract class DataFrameReader {
+ type DS[U] <: Dataset[U]
+
+ /**
+ * Specifies the input data source format.
+ *
+ * @since 1.4.0
+ */
+ def format(source: String): this.type = {
+ this.source = source
+ this
+ }
+
+ /**
+ * Specifies the input schema. Some data sources (e.g. JSON) can infer the input schema
+ * automatically from data. By specifying the schema here, the underlying data source can skip
+ * the schema inference step, and thus speed up data loading.
+ *
+ * @since 1.4.0
+ */
+ def schema(schema: StructType): this.type = {
+ if (schema != null) {
+ val replaced = SparkCharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType]
+ this.userSpecifiedSchema = Option(replaced)
+ validateSingleVariantColumn()
+ }
+ this
+ }
+
+ /**
+ * Specifies the schema by using the input DDL-formatted string. Some data sources (e.g. JSON)
+ * can infer the input schema automatically from data. By specifying the schema here, the
+ * underlying data source can skip the schema inference step, and thus speed up data loading.
+ *
+ * {{{
+ * spark.read.schema("a INT, b STRING, c DOUBLE").csv("test.csv")
+ * }}}
+ *
+ * @since 2.3.0
+ */
+ def schema(schemaString: String): this.type = schema(StructType.fromDDL(schemaString))
+
+ /**
+ * Adds an input option for the underlying data source.
+ *
+ * All options are maintained in a case-insensitive way in terms of key names. If a new option
+ * has the same key case-insensitively, it will override the existing option.
+ *
+ * @since 1.4.0
+ */
+ def option(key: String, value: String): this.type = {
+ this.extraOptions = this.extraOptions + (key -> value)
+ validateSingleVariantColumn()
+ this
+ }
+
+ /**
+ * Adds an input option for the underlying data source.
+ *
+ * All options are maintained in a case-insensitive way in terms of key names. If a new option
+ * has the same key case-insensitively, it will override the existing option.
+ *
+ * @since 2.0.0
+ */
+ def option(key: String, value: Boolean): this.type = option(key, value.toString)
+
+ /**
+ * Adds an input option for the underlying data source.
+ *
+ * All options are maintained in a case-insensitive way in terms of key names. If a new option
+ * has the same key case-insensitively, it will override the existing option.
+ *
+ * @since 2.0.0
+ */
+ def option(key: String, value: Long): this.type = option(key, value.toString)
+
+ /**
+ * Adds an input option for the underlying data source.
+ *
+ * All options are maintained in a case-insensitive way in terms of key names. If a new option
+ * has the same key case-insensitively, it will override the existing option.
+ *
+ * @since 2.0.0
+ */
+ def option(key: String, value: Double): this.type = option(key, value.toString)
+
+ /**
+ * (Scala-specific) Adds input options for the underlying data source.
+ *
+ * All options are maintained in a case-insensitive way in terms of key names. If a new option
+ * has the same key case-insensitively, it will override the existing option.
+ *
+ * @since 1.4.0
+ */
+ def options(options: scala.collection.Map[String, String]): this.type = {
+ this.extraOptions ++= options
+ validateSingleVariantColumn()
+ this
+ }
+
+ /**
+ * Adds input options for the underlying data source.
+ *
+ * All options are maintained in a case-insensitive way in terms of key names. If a new option
+ * has the same key case-insensitively, it will override the existing option.
+ *
+ * @since 1.4.0
+ */
+ def options(opts: util.Map[String, String]): this.type = options(opts.asScala)
+
+ /**
+ * Loads input in as a `DataFrame`, for data sources that don't require a path (e.g. external
+ * key-value stores).
+ *
+ * @since 1.4.0
+ */
+ def load(): Dataset[Row]
+
+ /**
+ * Loads input in as a `DataFrame`, for data sources that require a path (e.g. data backed by a
+ * local or distributed file system).
+ *
+ * @since 1.4.0
+ */
+ def load(path: String): Dataset[Row]
+
+ /**
+ * Loads input in as a `DataFrame`, for data sources that support multiple paths. Only works if
+ * the source is a HadoopFsRelationProvider.
+ *
+ * @since 1.6.0
+ */
+ @scala.annotation.varargs
+ def load(paths: String*): Dataset[Row]
+
+ /**
+ * Construct a `DataFrame` representing the database table named `table`, accessible via the
+ * JDBC URL `url` and connection properties.
+ *
+ * You can find the JDBC-specific option and parameter documentation for reading tables via JDBC
+ * in
+ * Data Source Option in the version you use.
+ *
+ * @since 1.4.0
+ */
+ def jdbc(url: String, table: String, properties: util.Properties): Dataset[Row] = {
+ assertNoSpecifiedSchema("jdbc")
+ // properties should override settings in extraOptions.
+ this.extraOptions ++= properties.asScala
+ // explicit url and dbtable should override all
+ this.extraOptions ++= Seq("url" -> url, "dbtable" -> table)
+ format("jdbc").load()
+ }
+
+ // scalastyle:off line.size.limit
+ /**
+ * Construct a `DataFrame` representing the database table named `table`, accessible via the
+ * JDBC URL `url`. Partitions of the table will be retrieved in parallel based on the
+ * parameters passed to this function.
+ *
+ * Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash
+ * your external database systems.
+ *
+ * You can find the JDBC-specific option and parameter documentation for reading tables via JDBC
+ * in
+ * Data Source Option in the version you use.
+ *
+ * @param table
+ * Name of the table in the external database.
+ * @param columnName
+ * Alias of `partitionColumn` option. Refer to `partitionColumn` in
+ * Data Source Option in the version you use.
+ * @param connectionProperties
+ * JDBC database connection arguments, a list of arbitrary string tag/value. Normally at least
+ * a "user" and "password" property should be included. "fetchsize" can be used to control the
+ * number of rows per fetch and "queryTimeout" can be used to wait for a Statement object to
+ * execute to the given number of seconds.
+ * @since 1.4.0
+ */
+ // scalastyle:on line.size.limit
+ def jdbc(
+ url: String,
+ table: String,
+ columnName: String,
+ lowerBound: Long,
+ upperBound: Long,
+ numPartitions: Int,
+ connectionProperties: util.Properties): Dataset[Row] = {
+ // columnName, lowerBound, upperBound and numPartitions override settings in extraOptions.
+ this.extraOptions ++= Map(
+ "partitionColumn" -> columnName,
+ "lowerBound" -> lowerBound.toString,
+ "upperBound" -> upperBound.toString,
+ "numPartitions" -> numPartitions.toString)
+ jdbc(url, table, connectionProperties)
+ }
+
+ /**
+ * Construct a `DataFrame` representing the database table named `table`, accessible via the
+ * JDBC URL `url` and connection properties. The `predicates` parameter gives a list of
+ * expressions suitable for inclusion in WHERE clauses; each one defines one partition of the
+ * `DataFrame`.
+ *
+ * Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash
+ * your external database systems.
+ *
+ * You can find the JDBC-specific option and parameter documentation for reading tables via JDBC
+ * in
+ * Data Source Option in the version you use.
+ *
+ * @param table
+ * Name of the table in the external database.
+ * @param predicates
+ * Condition in the where clause for each partition.
+ * @param connectionProperties
+ * JDBC database connection arguments, a list of arbitrary string tag/value. Normally at least
+ * a "user" and "password" property should be included. "fetchsize" can be used to control the
+ * number of rows per fetch.
+ * @since 1.4.0
+ */
+ def jdbc(
+ url: String,
+ table: String,
+ predicates: Array[String],
+ connectionProperties: util.Properties): Dataset[Row]
+
+ /**
+ * Loads a JSON file and returns the results as a `DataFrame`.
+ *
+ * See the documentation on the overloaded `json()` method with varargs for more details.
+ *
+ * @since 1.4.0
+ */
+ def json(path: String): Dataset[Row] = {
+ // This method ensures that calls that explicitly need a single argument work; see SPARK-16009
+ json(Seq(path): _*)
+ }
+
+ /**
+ * Loads JSON files and returns the results as a `DataFrame`.
+ *
+ * JSON Lines (newline-delimited JSON) is supported by
+ * default. For JSON (one record per file), set the `multiLine` option to true.
+ *
+ * This function goes through the input once to determine the input schema. If you know the
+ * schema in advance, use the version that specifies the schema to avoid the extra scan.
+ *
+ * You can find the JSON-specific options for reading JSON files in
+ * Data Source Option in the version you use.
+ *
+ * @since 2.0.0
+ */
+ @scala.annotation.varargs
+ def json(paths: String*): Dataset[Row] = {
+ validateJsonSchema()
+ format("json").load(paths: _*)
+ }
+
+ /**
+ * Loads a `Dataset[String]` storing JSON objects (JSON Lines
+ * text format or newline-delimited JSON) and returns the result as a `DataFrame`.
+ *
+ * Unless the schema is specified using `schema` function, this function goes through the input
+ * once to determine the input schema.
+ *
+ * @param jsonDataset
+ * input Dataset with one JSON object per record
+ * @since 2.2.0
+ */
+ def json(jsonDataset: DS[String]): Dataset[Row]
+
+ /**
+ * Loads a CSV file and returns the result as a `DataFrame`. See the documentation on the other
+ * overloaded `csv()` method for more details.
+ *
+ * @since 2.0.0
+ */
+ def csv(path: String): Dataset[Row] = {
+ // This method ensures that calls that explicitly need a single argument work; see SPARK-16009
+ csv(Seq(path): _*)
+ }
+
+ /**
+ * Loads a `Dataset[String]` storing CSV rows and returns the result as a `DataFrame`.
+ *
+ * If the schema is not specified using `schema` function and `inferSchema` option is enabled,
+ * this function goes through the input once to determine the input schema.
+ *
+ * If the schema is not specified using `schema` function and `inferSchema` option is disabled,
+ * it determines the columns as string types and it reads only the first line to determine the
+ * names and the number of fields.
+ *
+ * If `enforceSchema` is set to `false`, only the CSV header in the first line is checked to
+ * conform to the specified or inferred schema.
+ *
+ * @note
+ * if the `header` option is set to `true` when calling this API, all lines identical to the
+ * header will be removed if they exist.
+ *
+ * @param csvDataset
+ * input Dataset with one CSV row per record
+ * @since 2.2.0
+ */
+ def csv(csvDataset: DS[String]): Dataset[Row]
+
+ /**
+ * Loads CSV files and returns the result as a `DataFrame`.
+ *
+ * This function will go through the input once to determine the input schema if `inferSchema`
+ * is enabled. To avoid going through the entire data once, disable `inferSchema` option or
+ * specify the schema explicitly using `schema`.
+ *
+ * You can find the CSV-specific options for reading CSV files in
+ * Data Source Option in the version you use.
+ *
+ * @since 2.0.0
+ */
+ @scala.annotation.varargs
+ def csv(paths: String*): Dataset[Row] = format("csv").load(paths: _*)
+
+ /**
+ * Loads an XML file and returns the result as a `DataFrame`. See the documentation on the other
+ * overloaded `xml()` method for more details.
+ *
+ * @since 4.0.0
+ */
+ def xml(path: String): Dataset[Row] = {
+ // This method ensures that calls that explicitly need a single argument work; see SPARK-16009
+ xml(Seq(path): _*)
+ }
+
+ /**
+ * Loads XML files and returns the result as a `DataFrame`.
+ *
+ * This function will go through the input once to determine the input schema if `inferSchema`
+ * is enabled. To avoid going through the entire data once, disable `inferSchema` option or
+ * specify the schema explicitly using `schema`.
+ *
+ * You can find the XML-specific options for reading XML files in
+ * Data Source Option in the version you use.
+ *
+ * @since 4.0.0
+ */
+ @scala.annotation.varargs
+ def xml(paths: String*): Dataset[Row] = {
+ validateXmlSchema()
+ format("xml").load(paths: _*)
+ }
+
+ /**
+ * Loads a `Dataset[String]` storing XML objects and returns the result as a `DataFrame`.
+ *
+ * If the schema is not specified using `schema` function and `inferSchema` option is enabled,
+ * this function goes through the input once to determine the input schema.
+ *
+ * @param xmlDataset
+ * input Dataset with one XML object per record
+ * @since 4.0.0
+ */
+ def xml(xmlDataset: DS[String]): Dataset[Row]
+
+ /**
+ * Loads a Parquet file, returning the result as a `DataFrame`. See the documentation on the
+ * other overloaded `parquet()` method for more details.
+ *
+ * @since 2.0.0
+ */
+ def parquet(path: String): Dataset[Row] = {
+ // This method ensures that calls that explicitly need a single argument work; see SPARK-16009
+ parquet(Seq(path): _*)
+ }
+
+ /**
+ * Loads a Parquet file, returning the result as a `DataFrame`.
+ *
+ * Parquet-specific option(s) for reading Parquet files can be found in Data
+ * Source Option in the version you use.
+ *
+ * @since 1.4.0
+ */
+ @scala.annotation.varargs
+ def parquet(paths: String*): Dataset[Row] = format("parquet").load(paths: _*)
+
+ /**
+ * Loads an ORC file and returns the result as a `DataFrame`.
+ *
+ * @param path
+ * input path
+ * @since 1.5.0
+ */
+ def orc(path: String): Dataset[Row] = {
+ // This method ensures that calls that explicitly need a single argument work; see SPARK-16009
+ orc(Seq(path): _*)
+ }
+
+ /**
+ * Loads ORC files and returns the result as a `DataFrame`.
+ *
+ * ORC-specific option(s) for reading ORC files can be found in Data
+ * Source Option in the version you use.
+ *
+ * @param paths
+ * input paths
+ * @since 2.0.0
+ */
+ @scala.annotation.varargs
+ def orc(paths: String*): Dataset[Row] = format("orc").load(paths: _*)
+
+ /**
+ * Returns the specified table/view as a `DataFrame`. If it's a table, it must support batch
+ * reading and the returned DataFrame is the batch scan query plan of this table. If it's a
+ * view, the returned DataFrame is simply the query plan of the view, which can either be a
+ * batch or streaming query plan.
+ *
+ * @param tableName
+ * is either a qualified or unqualified name that designates a table or view. If a database is
+ * specified, it identifies the table/view from the database. Otherwise, it first attempts to
+ * find a temporary view with the given name and then matches the table/view from the current
+ * database. Note that the global temporary view database is also valid here.
+ * @since 1.4.0
+ */
+ def table(tableName: String): Dataset[Row]
+
+ /**
+ * Loads text files and returns a `DataFrame` whose schema starts with a string column named
+ * "value", and followed by partitioned columns if there are any. See the documentation on the
+ * other overloaded `text()` method for more details.
+ *
+ * @since 2.0.0
+ */
+ def text(path: String): Dataset[Row] = {
+ // This method ensures that calls that explicitly need a single argument work; see SPARK-16009
+ text(Seq(path): _*)
+ }
+
+ /**
+ * Loads text files and returns a `DataFrame` whose schema starts with a string column named
+ * "value", and followed by partitioned columns if there are any. The text files must be encoded
+ * as UTF-8.
+ *
+ * By default, each line in the text files is a new row in the resulting DataFrame. For example:
+ * {{{
+ * // Scala:
+ * spark.read.text("/path/to/spark/README.md")
+ *
+ * // Java:
+ * spark.read().text("/path/to/spark/README.md")
+ * }}}
+ *
+ * You can find the text-specific options for reading text files in
+ * Data Source Option in the version you use.
+ *
+ * @param paths
+ * input paths
+ * @since 1.6.0
+ */
+ @scala.annotation.varargs
+ def text(paths: String*): Dataset[Row] = format("text").load(paths: _*)
+
+ /**
+ * Loads text files and returns a [[Dataset]] of String. See the documentation on the other
+ * overloaded `textFile()` method for more details.
+ * @since 2.0.0
+ */
+ def textFile(path: String): Dataset[String] = {
+ // This method ensures that calls that explicitly need a single argument work; see SPARK-16009
+ textFile(Seq(path): _*)
+ }
+
+ /**
+ * Loads text files and returns a [[Dataset]] of String. The underlying schema of the Dataset
+ * contains a single string column named "value". The text files must be encoded as UTF-8.
+ *
+ * If the directory structure of the text files contains partitioning information, those are
+ * ignored in the resulting Dataset. To include partitioning information as columns, use `text`.
+ *
+ * By default, each line in the text files is a new row in the resulting DataFrame. For example:
+ * {{{
+ * // Scala:
+ * spark.read.textFile("/path/to/spark/README.md")
+ *
+ * // Java:
+ * spark.read().textFile("/path/to/spark/README.md")
+ * }}}
+ *
+ * You can set the text-specific options as specified in `DataFrameReader.text`.
+ *
+ * @param paths
+ * input paths
+ * @since 2.0.0
+ */
+ @scala.annotation.varargs
+ def textFile(paths: String*): Dataset[String] = {
+ assertNoSpecifiedSchema("textFile")
+ text(paths: _*).select("value").as(StringEncoder)
+ }
+
+ /**
+ * A convenient function for schema validation in APIs.
+ */
+ protected def assertNoSpecifiedSchema(operation: String): Unit = {
+ if (userSpecifiedSchema.nonEmpty) {
+ throw DataTypeErrors.userSpecifiedSchemaUnsupportedError(operation)
+ }
+ }
+
+ /**
+ * Ensure that the `singleVariantColumn` option cannot be used if there is also a user specified
+ * schema.
+ */
+ protected def validateSingleVariantColumn(): Unit = ()
+
+ protected def validateJsonSchema(): Unit = ()
+
+ protected def validateXmlSchema(): Unit = ()
+
+ ///////////////////////////////////////////////////////////////////////////////////////
+ // Builder pattern config options
+ ///////////////////////////////////////////////////////////////////////////////////////
+
+ protected var source: String = _
+
+ protected var userSpecifiedSchema: Option[StructType] = None
+
+ protected var extraOptions: CaseInsensitiveMap[String] = CaseInsensitiveMap[String](Map.empty)
+}
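
Since `DataFrameReader` is introduced here as a new shared interface, a short, hedged usage sketch may help tie the builder methods together. It assumes a `SparkSession` named `spark`; the paths, schema, and JDBC coordinates below are hypothetical.

    import java.util.Properties

    // Generic form: choose a format, set options, optionally pin the schema, then load.
    val events = spark.read
      .format("json")
      .schema("id INT, name STRING, score DOUBLE")
      .option("multiLine", true)
      .load("/data/events")

    // Format-specific shortcuts.
    val people = spark.read.option("header", true).option("inferSchema", true).csv("/data/people.csv")
    val readme = spark.read.text("/data/README.md")      // DataFrame with a "value" column
    val lines  = spark.read.textFile("/data/README.md")  // Dataset[String]

    // JDBC: values in `props` override options previously set on the reader.
    val props = new Properties()
    props.setProperty("user", "admin")
    props.setProperty("password", "secret")
    val orders = spark.read.jdbc("jdbc:postgresql://db:5432/shop", "public.orders", props)
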
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameStatFunctions.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameStatFunctions.scala
index c3ecc7b90d5b4..ae7c256b30ace 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameStatFunctions.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameStatFunctions.scala
@@ -34,38 +34,41 @@ import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch}
* @since 1.4.0
*/
@Stable
-abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] {
- protected def df: DS[Row]
+abstract class DataFrameStatFunctions {
+ protected def df: Dataset[Row]
/**
* Calculates the approximate quantiles of a numerical column of a DataFrame.
*
- * The result of this algorithm has the following deterministic bound:
- * If the DataFrame has N elements and if we request the quantile at probability `p` up to error
- * `err`, then the algorithm will return a sample `x` from the DataFrame so that the *exact* rank
- * of `x` is close to (p * N).
- * More precisely,
+ * The result of this algorithm has the following deterministic bound: If the DataFrame has N
+ * elements and if we request the quantile at probability `p` up to error `err`, then the
+ * algorithm will return a sample `x` from the DataFrame so that the *exact* rank of `x` is
+ * close to (p * N). More precisely,
*
* {{{
* floor((p - err) * N) <= rank(x) <= ceil((p + err) * N)
* }}}
*
* This method implements a variation of the Greenwald-Khanna algorithm (with some speed
- * optimizations).
- * The algorithm was first present in
- * Space-efficient Online Computation of Quantile Summaries by Greenwald and Khanna.
- *
- * @param col the name of the numerical column
- * @param probabilities a list of quantile probabilities
- * Each number must belong to [0, 1].
- * For example 0 is the minimum, 0.5 is the median, 1 is the maximum.
- * @param relativeError The relative target precision to achieve (greater than or equal to 0).
- * If set to zero, the exact quantiles are computed, which could be very expensive.
- * Note that values greater than 1 are accepted but give the same result as 1.
- * @return the approximate quantiles at the given probabilities
- *
- * @note null and NaN values will be removed from the numerical column before calculation. If
- * the dataframe is empty or the column only contains null or NaN, an empty array is returned.
+ * optimizations). The algorithm was first presented in Space-efficient Online Computation of Quantile
+ * Summaries by Greenwald and Khanna.
+ *
+ * @param col
+ * the name of the numerical column
+ * @param probabilities
+ * a list of quantile probabilities. Each number must belong to [0, 1]. For example 0 is the
+ * minimum, 0.5 is the median, 1 is the maximum.
+ * @param relativeError
+ * The relative target precision to achieve (greater than or equal to 0). If set to zero, the
+ * exact quantiles are computed, which could be very expensive. Note that values greater than
+ * 1 are accepted but give the same result as 1.
+ * @return
+ * the approximate quantiles at the given probabilities
+ *
+ * @note
+ * null and NaN values will be removed from the numerical column before calculation. If the
+ * dataframe is empty or the column only contains null or NaN, an empty array is returned.
*
* @since 2.0.0
*/
@@ -78,19 +81,24 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] {
/**
* Calculates the approximate quantiles of numerical columns of a DataFrame.
- * @see `approxQuantile(col:Str* approxQuantile)` for detailed description.
- *
- * @param cols the names of the numerical columns
- * @param probabilities a list of quantile probabilities
- * Each number must belong to [0, 1].
- * For example 0 is the minimum, 0.5 is the median, 1 is the maximum.
- * @param relativeError The relative target precision to achieve (greater than or equal to 0).
- * If set to zero, the exact quantiles are computed, which could be very expensive.
- * Note that values greater than 1 are accepted but give the same result as 1.
- * @return the approximate quantiles at the given probabilities of each column
- *
- * @note null and NaN values will be ignored in numerical columns before calculation. For
- * columns only containing null or NaN values, an empty array is returned.
+ * @see
+ * `approxQuantile(col:Str* approxQuantile)` for detailed description.
+ *
+ * @param cols
+ * the names of the numerical columns
+ * @param probabilities
+ * a list of quantile probabilities. Each number must belong to [0, 1]. For example 0 is the
+ * minimum, 0.5 is the median, 1 is the maximum.
+ * @param relativeError
+ * The relative target precision to achieve (greater than or equal to 0). If set to zero, the
+ * exact quantiles are computed, which could be very expensive. Note that values greater than
+ * 1 are accepted but give the same result as 1.
+ * @return
+ * the approximate quantiles at the given probabilities of each column
+ *
+ * @note
+ * null and NaN values will be ignored in numerical columns before calculation. For columns
+ * only containing null or NaN values, an empty array is returned.
*
* @since 2.2.0
*/
@@ -102,9 +110,12 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] {
/**
* Calculate the sample covariance of two numerical columns of a DataFrame.
*
- * @param col1 the name of the first column
- * @param col2 the name of the second column
- * @return the covariance of the two columns.
+ * @param col1
+ * the name of the first column
+ * @param col2
+ * the name of the second column
+ * @return
+ * the covariance of the two columns.
*
* {{{
* val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10))
@@ -121,9 +132,12 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] {
* Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in
* MLlib's Statistics.
*
- * @param col1 the name of the column
- * @param col2 the name of the column to calculate the correlation against
- * @return The Pearson Correlation Coefficient as a Double.
+ * @param col1
+ * the name of the column
+ * @param col2
+ * the name of the column to calculate the correlation against
+ * @return
+ * The Pearson Correlation Coefficient as a Double.
*
* {{{
* val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10))
@@ -138,9 +152,12 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] {
/**
* Calculates the Pearson Correlation Coefficient of two columns of a DataFrame.
*
- * @param col1 the name of the column
- * @param col2 the name of the column to calculate the correlation against
- * @return The Pearson Correlation Coefficient as a Double.
+ * @param col1
+ * the name of the column
+ * @param col2
+ * the name of the column to calculate the correlation against
+ * @return
+ * The Pearson Correlation Coefficient as a Double.
*
* {{{
* val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10))
@@ -159,14 +176,15 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] {
* The first column of each row will be the distinct values of `col1` and the column names will
* be the distinct values of `col2`. The name of the first column will be `col1_col2`. Counts
* will be returned as `Long`s. Pairs that have no occurrences will have zero as their counts.
- * Null elements will be replaced by "null", and back ticks will be dropped from elements if they
- * exist.
+ * Null elements will be replaced by "null", and back ticks will be dropped from elements if
+ * they exist.
*
- * @param col1 The name of the first column. Distinct items will make the first item of
- * each row.
- * @param col2 The name of the second column. Distinct items will make the column names
- * of the DataFrame.
- * @return A DataFrame containing for the contingency table.
+ * @param col1
+ * The name of the first column. Distinct items will make the first item of each row.
+ * @param col2
+ * The name of the second column. Distinct items will make the column names of the DataFrame.
+ * @return
+ * A DataFrame containing the contingency table.
*
* {{{
* val df = spark.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2, 3), (3, 2), (3, 3)))
@@ -184,22 +202,22 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] {
*
* @since 1.4.0
*/
- def crosstab(col1: String, col2: String): DS[Row]
+ def crosstab(col1: String, col2: String): Dataset[Row]
/**
- * Finding frequent items for columns, possibly with false positives. Using the
- * frequent element count algorithm described in
- * here, proposed by Karp,
- * Schenker, and Papadimitriou.
- * The `support` should be greater than 1e-4.
+ * Finding frequent items for columns, possibly with false positives. Using the frequent element
+ * count algorithm described in here,
+ * proposed by Karp, Schenker, and Papadimitriou. The `support` should be greater than 1e-4.
*
* This function is meant for exploratory data analysis, as we make no guarantee about the
* backward compatibility of the schema of the resulting `DataFrame`.
*
- * @param cols the names of the columns to search frequent items in.
- * @param support The minimum frequency for an item to be considered `frequent`. Should be greater
- * than 1e-4.
- * @return A Local DataFrame with the Array of frequent items for each column.
+ * @param cols
+ * the names of the columns to search frequent items in.
+ * @param support
+ * The minimum frequency for an item to be considered `frequent`. Should be greater than 1e-4.
+ * @return
+ * A Local DataFrame with the Array of frequent items for each column.
*
* {{{
* val rows = Seq.tabulate(100) { i =>
@@ -228,36 +246,38 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] {
* }}}
* @since 1.4.0
*/
- def freqItems(cols: Array[String], support: Double): DS[Row] =
+ def freqItems(cols: Array[String], support: Double): Dataset[Row] =
freqItems(cols.toImmutableArraySeq, support)
/**
- * Finding frequent items for columns, possibly with false positives. Using the
- * frequent element count algorithm described in
- * here, proposed by Karp,
- * Schenker, and Papadimitriou.
- * Uses a `default` support of 1%.
+ * Finding frequent items for columns, possibly with false positives. Using the frequent element
+ * count algorithm described in here,
+ * proposed by Karp, Schenker, and Papadimitriou. Uses a `default` support of 1%.
*
* This function is meant for exploratory data analysis, as we make no guarantee about the
* backward compatibility of the schema of the resulting `DataFrame`.
*
- * @param cols the names of the columns to search frequent items in.
- * @return A Local DataFrame with the Array of frequent items for each column.
+ * @param cols
+ * the names of the columns to search frequent items in.
+ * @return
+ * A Local DataFrame with the Array of frequent items for each column.
* @since 1.4.0
*/
- def freqItems(cols: Array[String]): DS[Row] = freqItems(cols, 0.01)
+ def freqItems(cols: Array[String]): Dataset[Row] = freqItems(cols, 0.01)
/**
* (Scala-specific) Finding frequent items for columns, possibly with false positives. Using the
- * frequent element count algorithm described in
- * here, proposed by Karp, Schenker,
- * and Papadimitriou.
+ * frequent element count algorithm described in here, proposed by Karp, Schenker, and
+ * Papadimitriou.
*
* This function is meant for exploratory data analysis, as we make no guarantee about the
* backward compatibility of the schema of the resulting `DataFrame`.
*
- * @param cols the names of the columns to search frequent items in.
- * @return A Local DataFrame with the Array of frequent items for each column.
+ * @param cols
+ * the names of the columns to search frequent items in.
+ * @return
+ * A Local DataFrame with the Array of frequent items for each column.
*
* {{{
* val rows = Seq.tabulate(100) { i =>
@@ -287,32 +307,38 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] {
*
* @since 1.4.0
*/
- def freqItems(cols: Seq[String], support: Double): DS[Row]
+ def freqItems(cols: Seq[String], support: Double): Dataset[Row]
/**
* (Scala-specific) Finding frequent items for columns, possibly with false positives. Using the
- * frequent element count algorithm described in
- * here, proposed by Karp, Schenker,
- * and Papadimitriou.
- * Uses a `default` support of 1%.
+ * frequent element count algorithm described in here, proposed by Karp, Schenker, and
+ * Papadimitriou. Uses a `default` support of 1%.
*
* This function is meant for exploratory data analysis, as we make no guarantee about the
* backward compatibility of the schema of the resulting `DataFrame`.
*
- * @param cols the names of the columns to search frequent items in.
- * @return A Local DataFrame with the Array of frequent items for each column.
+ * @param cols
+ * the names of the columns to search frequent items in.
+ * @return
+ * A Local DataFrame with the Array of frequent items for each column.
* @since 1.4.0
*/
- def freqItems(cols: Seq[String]): DS[Row] = freqItems(cols, 0.01)
+ def freqItems(cols: Seq[String]): Dataset[Row] = freqItems(cols, 0.01)
/**
* Returns a stratified sample without replacement based on the fraction given on each stratum.
- * @param col column that defines strata
- * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat
- * its fraction as zero.
- * @param seed random seed
- * @tparam T stratum type
- * @return a new `DataFrame` that represents the stratified sample
+ * @param col
+ * column that defines strata
+ * @param fractions
+ * sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as
+ * zero.
+ * @param seed
+ * random seed
+ * @tparam T
+ * stratum type
+ * @return
+ * a new `DataFrame` that represents the stratified sample
*
* {{{
* val df = spark.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2, 3), (3, 2),
@@ -330,33 +356,43 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] {
*
* @since 1.5.0
*/
- def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DS[Row] = {
+ def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): Dataset[Row] = {
sampleBy(Column(col), fractions, seed)
}
/**
* Returns a stratified sample without replacement based on the fraction given on each stratum.
- * @param col column that defines strata
- * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat
- * its fraction as zero.
- * @param seed random seed
- * @tparam T stratum type
- * @return a new `DataFrame` that represents the stratified sample
+ * @param col
+ * column that defines strata
+ * @param fractions
+ * sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as
+ * zero.
+ * @param seed
+ * random seed
+ * @tparam T
+ * stratum type
+ * @return
+ * a new `DataFrame` that represents the stratified sample
*
* @since 1.5.0
*/
- def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DS[Row] = {
+ def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): Dataset[Row] = {
sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed)
}
/**
* Returns a stratified sample without replacement based on the fraction given on each stratum.
- * @param col column that defines strata
- * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat
- * its fraction as zero.
- * @param seed random seed
- * @tparam T stratum type
- * @return a new `DataFrame` that represents the stratified sample
+ * @param col
+ * column that defines strata
+ * @param fractions
+ * sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as
+ * zero.
+ * @param seed
+ * random seed
+ * @tparam T
+ * stratum type
+ * @return
+ * a new `DataFrame` that represents the stratified sample
*
* The stratified sample can be performed over multiple columns:
* {{{
@@ -377,33 +413,42 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] {
*
* @since 3.0.0
*/
- def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): DS[Row]
-
+ def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): Dataset[Row]
/**
* (Java-specific) Returns a stratified sample without replacement based on the fraction given
* on each stratum.
*
- * @param col column that defines strata
- * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat
- * its fraction as zero.
- * @param seed random seed
- * @tparam T stratum type
- * @return a new `DataFrame` that represents the stratified sample
+ * @param col
+ * column that defines strata
+ * @param fractions
+ * sampling fraction for each stratum. If a stratum is not specified, we treat its fraction as
+ * zero.
+ * @param seed
+ * random seed
+ * @tparam T
+ * stratum type
+ * @return
+ * a new `DataFrame` that represents the stratified sample
* @since 3.0.0
*/
- def sampleBy[T](col: Column, fractions: ju.Map[T, jl.Double], seed: Long): DS[Row] = {
+ def sampleBy[T](col: Column, fractions: ju.Map[T, jl.Double], seed: Long): Dataset[Row] = {
sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed)
}
/**
* Builds a Count-min Sketch over a specified column.
*
- * @param colName name of the column over which the sketch is built
- * @param depth depth of the sketch
- * @param width width of the sketch
- * @param seed random seed
- * @return a `CountMinSketch` over column `colName`
+ * @param colName
+ * name of the column over which the sketch is built
+ * @param depth
+ * depth of the sketch
+ * @param width
+ * width of the sketch
+ * @param seed
+ * random seed
+ * @return
+ * a `CountMinSketch` over column `colName`
* @since 2.0.0
*/
def countMinSketch(colName: String, depth: Int, width: Int, seed: Int): CountMinSketch = {
@@ -413,26 +458,39 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] {
/**
* Builds a Count-min Sketch over a specified column.
*
- * @param colName name of the column over which the sketch is built
- * @param eps relative error of the sketch
- * @param confidence confidence of the sketch
- * @param seed random seed
- * @return a `CountMinSketch` over column `colName`
+ * @param colName
+ * name of the column over which the sketch is built
+ * @param eps
+ * relative error of the sketch
+ * @param confidence
+ * confidence of the sketch
+ * @param seed
+ * random seed
+ * @return
+ * a `CountMinSketch` over column `colName`
* @since 2.0.0
*/
def countMinSketch(
- colName: String, eps: Double, confidence: Double, seed: Int): CountMinSketch = {
+ colName: String,
+ eps: Double,
+ confidence: Double,
+ seed: Int): CountMinSketch = {
countMinSketch(Column(colName), eps, confidence, seed)
}
/**
* Builds a Count-min Sketch over a specified column.
*
- * @param col the column over which the sketch is built
- * @param depth depth of the sketch
- * @param width width of the sketch
- * @param seed random seed
- * @return a `CountMinSketch` over column `colName`
+ * @param col
+ * the column over which the sketch is built
+ * @param depth
+ * depth of the sketch
+ * @param width
+ * width of the sketch
+ * @param seed
+ * random seed
+ * @return
+ * a `CountMinSketch` over column `colName`
* @since 2.0.0
*/
def countMinSketch(col: Column, depth: Int, width: Int, seed: Int): CountMinSketch = {
@@ -444,29 +502,34 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] {
/**
* Builds a Count-min Sketch over a specified column.
*
- * @param col the column over which the sketch is built
- * @param eps relative error of the sketch
- * @param confidence confidence of the sketch
- * @param seed random seed
- * @return a `CountMinSketch` over column `colName`
+ * @param col
+ * the column over which the sketch is built
+ * @param eps
+ * relative error of the sketch
+ * @param confidence
+ * confidence of the sketch
+ * @param seed
+ * random seed
+ * @return
+ * a `CountMinSketch` over column `colName`
* @since 2.0.0
*/
- def countMinSketch(
- col: Column,
- eps: Double,
- confidence: Double,
- seed: Int): CountMinSketch = withOrigin {
- val cms = count_min_sketch(col, lit(eps), lit(confidence), lit(seed))
- val bytes: Array[Byte] = df.select(cms).as(BinaryEncoder).head()
- CountMinSketch.readFrom(bytes)
- }
+ def countMinSketch(col: Column, eps: Double, confidence: Double, seed: Int): CountMinSketch =
+ withOrigin {
+ val cms = count_min_sketch(col, lit(eps), lit(confidence), lit(seed))
+ val bytes: Array[Byte] = df.select(cms).as(BinaryEncoder).head()
+ CountMinSketch.readFrom(bytes)
+ }
/**
* Builds a Bloom filter over a specified column.
*
- * @param colName name of the column over which the filter is built
- * @param expectedNumItems expected number of items which will be put into the filter.
- * @param fpp expected false positive probability of the filter.
+ * @param colName
+ * name of the column over which the filter is built
+ * @param expectedNumItems
+ * expected number of items which will be put into the filter.
+ * @param fpp
+ * expected false positive probability of the filter.
* @since 2.0.0
*/
def bloomFilter(colName: String, expectedNumItems: Long, fpp: Double): BloomFilter = {
@@ -476,9 +539,12 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] {
/**
* Builds a Bloom filter over a specified column.
*
- * @param col the column over which the filter is built
- * @param expectedNumItems expected number of items which will be put into the filter.
- * @param fpp expected false positive probability of the filter.
+ * @param col
+ * the column over which the filter is built
+ * @param expectedNumItems
+ * expected number of items which will be put into the filter.
+ * @param fpp
+ * expected false positive probability of the filter.
* @since 2.0.0
*/
def bloomFilter(col: Column, expectedNumItems: Long, fpp: Double): BloomFilter = {
@@ -489,9 +555,12 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] {
/**
* Builds a Bloom filter over a specified column.
*
- * @param colName name of the column over which the filter is built
- * @param expectedNumItems expected number of items which will be put into the filter.
- * @param numBits expected number of bits of the filter.
+ * @param colName
+ * name of the column over which the filter is built
+ * @param expectedNumItems
+ * expected number of items which will be put into the filter.
+ * @param numBits
+ * expected number of bits of the filter.
* @since 2.0.0
*/
def bloomFilter(colName: String, expectedNumItems: Long, numBits: Long): BloomFilter = {
@@ -501,9 +570,12 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] {
/**
* Builds a Bloom filter over a specified column.
*
- * @param col the column over which the filter is built
- * @param expectedNumItems expected number of items which will be put into the filter.
- * @param numBits expected number of bits of the filter.
+ * @param col
+ * the column over which the filter is built
+ * @param expectedNumItems
+ * expected number of items which will be put into the filter.
+ * @param numBits
+ * expected number of bits of the filter.
* @since 2.0.0
*/
def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = withOrigin {
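
To round out the `DataFrameStatFunctions` changes above, here is a minimal sketch of the statistics helpers exposed through `Dataset.stat`. It assumes a `SparkSession` named `spark`; the data and column names are hypothetical.

    import spark.implicits._
    import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch}

    val df = Seq((1, "a"), (2, "b"), (2, "a"), (3, "c"), (4, "a")).toDF("id", "tag")

    // Approximate quantiles of a numeric column with a 1% relative error bound.
    val quartiles = df.stat.approxQuantile("id", Array(0.25, 0.5, 0.75), 0.01)

    // Pearson correlation and sample covariance of two numeric columns.
    val correlation = df.stat.corr("id", "id")
    val covariance  = df.stat.cov("id", "id")

    // Contingency table and frequent items.
    val table = df.stat.crosstab("id", "tag")
    val freq  = df.stat.freqItems(Array("tag"), 0.3)

    // Stratified sample: keep roughly 50% of the rows whose tag is "a".
    val sampled = df.stat.sampleBy("tag", Map("a" -> 0.5), seed = 42L)

    // Probabilistic sketches over a column.
    val cms: CountMinSketch = df.stat.countMinSketch("id", depth = 10, width = 100, seed = 42)
    val bf: BloomFilter     = df.stat.bloomFilter("id", 1000L, 0.01)
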
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
index 49f77a1a61204..6eef034aa5157 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
@@ -23,15 +23,16 @@ import _root_.java.util
import org.apache.spark.annotation.{DeveloperApi, Stable}
import org.apache.spark.api.java.function.{FilterFunction, FlatMapFunction, ForeachFunction, ForeachPartitionFunction, MapFunction, MapPartitionsFunction, ReduceFunction}
-import org.apache.spark.sql.{functions, AnalysisException, Column, DataFrameWriter, Encoder, Observation, Row, TypedColumn}
+import org.apache.spark.sql.{functions, AnalysisException, Column, DataFrameWriter, DataFrameWriterV2, Encoder, MergeIntoWriter, Observation, Row, TypedColumn}
+import org.apache.spark.sql.internal.{ToScalaUDF, UDFAdaptors}
import org.apache.spark.sql.types.{Metadata, StructType}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.ArrayImplicits._
import org.apache.spark.util.SparkClassUtils
/**
- * A Dataset is a strongly typed collection of domain-specific objects that can be transformed
- * in parallel using functional or relational operations. Each Dataset also has an untyped view
+ * A Dataset is a strongly typed collection of domain-specific objects that can be transformed in
+ * parallel using functional or relational operations. Each Dataset also has an untyped view
* called a `DataFrame`, which is a Dataset of [[org.apache.spark.sql.Row]].
*
* Operations available on Datasets are divided into transformations and actions. Transformations
@@ -39,29 +40,29 @@ import org.apache.spark.util.SparkClassUtils
* return results. Example transformations include map, filter, select, and aggregate (`groupBy`).
* Example actions count, show, or writing data out to file systems.
*
- * Datasets are "lazy", i.e. computations are only triggered when an action is invoked. Internally,
- * a Dataset represents a logical plan that describes the computation required to produce the data.
- * When an action is invoked, Spark's query optimizer optimizes the logical plan and generates a
- * physical plan for efficient execution in a parallel and distributed manner. To explore the
- * logical plan as well as optimized physical plan, use the `explain` function.
+ * Datasets are "lazy", i.e. computations are only triggered when an action is invoked.
+ * Internally, a Dataset represents a logical plan that describes the computation required to
+ * produce the data. When an action is invoked, Spark's query optimizer optimizes the logical plan
+ * and generates a physical plan for efficient execution in a parallel and distributed manner. To
+ * explore the logical plan as well as optimized physical plan, use the `explain` function.
*
- * To efficiently support domain-specific objects, an [[org.apache.spark.sql.Encoder]] is required.
- * The encoder maps the domain specific type `T` to Spark's internal type system. For example, given
- * a class `Person` with two fields, `name` (string) and `age` (int), an encoder is used to tell
- * Spark to generate code at runtime to serialize the `Person` object into a binary structure. This
- * binary structure often has much lower memory footprint as well as are optimized for efficiency
- * in data processing (e.g. in a columnar format). To understand the internal binary representation
- * for data, use the `schema` function.
+ * To efficiently support domain-specific objects, an [[org.apache.spark.sql.Encoder]] is
+ * required. The encoder maps the domain specific type `T` to Spark's internal type system. For
+ * example, given a class `Person` with two fields, `name` (string) and `age` (int), an encoder is
+ * used to tell Spark to generate code at runtime to serialize the `Person` object into a binary
+ * structure. This binary structure often has a much lower memory footprint and is optimized
+ * for efficiency in data processing (e.g. in a columnar format). To understand the internal
+ * binary representation for data, use the `schema` function.
*
- * There are typically two ways to create a Dataset. The most common way is by pointing Spark
- * to some files on storage systems, using the `read` function available on a `SparkSession`.
+ * There are typically two ways to create a Dataset. The most common way is by pointing Spark to
+ * some files on storage systems, using the `read` function available on a `SparkSession`.
* {{{
* val people = spark.read.parquet("...").as[Person] // Scala
* Dataset people = spark.read().parquet("...").as(Encoders.bean(Person.class)); // Java
* }}}
*
- * Datasets can also be created through transformations available on existing Datasets. For example,
- * the following creates a new Dataset by applying a filter on the existing one:
+ * Datasets can also be created through transformations available on existing Datasets. For
+ * example, the following creates a new Dataset by applying a filter on the existing one:
* {{{
* val names = people.map(_.name) // in Scala; names is a Dataset[String]
* Dataset names = people.map(
@@ -70,8 +71,8 @@ import org.apache.spark.util.SparkClassUtils
*
* Dataset operations can also be untyped, through various domain-specific-language (DSL)
* functions defined in: Dataset (this class), [[org.apache.spark.sql.Column]], and
- * [[org.apache.spark.sql.functions]]. These operations are very similar to the operations available
- * in the data frame abstraction in R or Python.
+ * [[org.apache.spark.sql.functions]]. These operations are very similar to the operations
+ * available in the data frame abstraction in R or Python.
*
* To select a column from the Dataset, use `apply` method in Scala and `col` in Java.
* {{{
@@ -118,10 +119,10 @@ import org.apache.spark.util.SparkClassUtils
* @since 1.6.0
*/
@Stable
-abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
- type RGD <: RelationalGroupedDataset[DS]
+abstract class Dataset[T] extends Serializable {
+ type DS[U] <: Dataset[U]
- def sparkSession: SparkSession[DS]
+ def sparkSession: SparkSession
val encoder: Encoder[T]
@@ -135,52 +136,46 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
*/
// This is declared with parentheses to prevent the Scala compiler from treating
// `ds.toDF("1")` as invoking this toDF and then apply on the returned DataFrame.
- def toDF(): DS[Row]
-
- /**
- * Returns a new Dataset where each record has been mapped on to the specified type. The
- * method used to map columns depend on the type of `U`:
- * <ul>
- * <li>When `U` is a class, fields for the class will be mapped to columns of the same name
- * (case sensitivity is determined by `spark.sql.caseSensitive`).</li>
- * <li>When `U` is a tuple, the columns will be mapped by ordinal (i.e. the first column will
- * be assigned to `_1`).</li>
- * <li>When `U` is a primitive type (i.e. String, Int, etc), then the first column of the
- * `DataFrame` will be used.</li>
+ def toDF(): Dataset[Row]
+
+ /**
+ * Returns a new Dataset where each record has been mapped on to the specified type. The method
+ * used to map columns depend on the type of `U`: <ul> <li>When `U` is a class, fields for the
+ * class will be mapped to columns of the same name (case sensitivity is determined by
+ * `spark.sql.caseSensitive`).</li> <li>When `U` is a tuple, the columns will be mapped by
+ * ordinal (i.e. the first column will be assigned to `_1`).</li> <li>When `U` is a primitive
+ * type (i.e. String, Int, etc), then the first column of the `DataFrame` will be used.</li>
* </ul>
*
- * If the schema of the Dataset does not match the desired `U` type, you can use `select`
- * along with `alias` or `as` to rearrange or rename as required.
+ * If the schema of the Dataset does not match the desired `U` type, you can use `select` along
+ * with `alias` or `as` to rearrange or rename as required.
*
- * Note that `as[]` only changes the view of the data that is passed into typed operations,
- * such as `map()`, and does not eagerly project away any columns that are not present in
- * the specified class.
+ * Note that `as[]` only changes the view of the data that is passed into typed operations, such
+ * as `map()`, and does not eagerly project away any columns that are not present in the
+ * specified class.
*
* @group basic
* @since 1.6.0
*/
- def as[U: Encoder]: DS[U]
+ def as[U: Encoder]: Dataset[U]
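// Hedged sketch of `as[U]` (assumes `spark`, a parquet path, and a top-level
// `case class Person(name: String, age: Int)`; all illustrative, not from this diff).
import spark.implicits._
import org.apache.spark.sql.functions.col
val people = spark.read.parquet("/path/to/people").as[Person]
// If column names do not line up, rearrange/rename with `select` and `alias` first:
val renamed = spark.read.parquet("/path/to/people")
  .select(col("full_name").as("name"), col("age"))
  .as[Person]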
/**
- * Returns a new DataFrame where each row is reconciled to match the specified schema. Spark will:
- * <ul>
- * <li>Reorder columns and/or inner fields by name to match the specified schema.</li>
- * <li>Project away columns and/or inner fields that are not needed by the specified schema.
- * Missing columns and/or inner fields (present in the specified schema but not input DataFrame)
- * lead to failures.</li>
- * <li>Cast the columns and/or inner fields to match the data types in the specified schema, if
- * the types are compatible, e.g., numeric to numeric (error if overflows), but not string to
- * int.</li>
- * <li>Carry over the metadata from the specified schema, while the columns and/or inner fields
- * still keep their own metadata if not overwritten by the specified schema.</li>
- * <li>Fail if the nullability is not compatible. For example, the column and/or inner field is
- * nullable but the specified schema requires them to be not nullable.</li>
- * </ul>
+ * Returns a new DataFrame where each row is reconciled to match the specified schema. Spark
+ * will: <ul> <li>Reorder columns and/or inner fields by name to match the specified
+ * schema.</li> <li>Project away columns and/or inner fields that are not needed by the
+ * specified schema. Missing columns and/or inner fields (present in the specified schema but
+ * not input DataFrame) lead to failures.</li> <li>Cast the columns and/or inner fields to match
+ * the data types in the specified schema, if the types are compatible, e.g., numeric to numeric
+ * (error if overflows), but not string to int.</li> <li>Carry over the metadata from the
+ * specified schema, while the columns and/or inner fields still keep their own metadata if not
+ * overwritten by the specified schema.</li> <li>Fail if the nullability is not compatible. For
+ * example, the column and/or inner field is nullable but the specified schema requires them to
+ * be not nullable.</li> </ul>
*
* @group basic
* @since 3.4.0
*/
- def to(schema: StructType): DS[Row]
+ def to(schema: StructType): Dataset[Row]
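// Minimal sketch of `to(schema)` (assumes an existing DataFrame `df`; names are illustrative):
// rows are reconciled to the target schema, reordering columns and casting compatible types.
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}
val target = StructType(Seq(
  StructField("id", LongType, nullable = false),
  StructField("name", StringType)))
val reconciled = df.to(target) // fails if a required column is missing or nullability conflicts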
/**
* Converts this strongly typed collection of data to generic `DataFrame` with columns renamed.
@@ -196,7 +191,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @since 2.0.0
*/
@scala.annotation.varargs
- def toDF(colNames: String*): DS[Row]
+ def toDF(colNames: String*): Dataset[Row]
/**
* Returns the schema of this Dataset.
@@ -228,16 +223,12 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
/**
* Prints the plans (logical and physical) with a format specified by a given explain mode.
*
- * @param mode specifies the expected output format of plans.
- * <ul>
- * <li>`simple` Print only a physical plan.</li>
- * <li>`extended`: Print both logical and physical plans.</li>
- * <li>`codegen`: Print a physical plan and generated codes if they are
- * available.</li>
- * <li>`cost`: Print a logical plan and statistics if they are available.</li>
- * <li>`formatted`: Split explain output into two sections: a physical plan outline
- * and node details.</li>
- * </ul>
+ * @param mode
+ * specifies the expected output format of plans. <ul> <li>`simple` Print only a physical
+ * plan.</li> <li>`extended`: Print both logical and physical plans.</li> <li>`codegen`: Print
+ * a physical plan and generated codes if they are available.</li> <li>`cost`: Print a logical
+ * plan and statistics if they are available.</li> <li>`formatted`: Split explain output into
+ * two sections: a physical plan outline and node details.</li> </ul>
* @group basic
* @since 3.0.0
*/
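// Illustrative calls for the explain modes listed above (Dataset `df` is assumed):
df.explain("simple")    // physical plan only
df.explain("extended")  // logical and physical plans
df.explain("codegen")   // physical plan plus generated code, when available
df.explain("cost")      // logical plan with statistics, when available
df.explain("formatted") // physical plan outline followed by node details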
@@ -246,12 +237,12 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
/**
* Prints the plans (logical and physical) to the console for debugging purposes.
*
- * @param extended default `false`. If `false`, prints only the physical plan.
+ * @param extended
+ * default `false`. If `false`, prints only the physical plan.
* @group basic
* @since 1.6.0
*/
def explain(extended: Boolean): Unit = if (extended) {
- // TODO move ExplainMode?
explain("extended")
} else {
explain("simple")
@@ -284,8 +275,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
def columns: Array[String] = schema.fields.map(_.name)
/**
- * Returns true if the `collect` and `take` methods can be run locally
- * (without any Spark executors).
+ * Returns true if the `collect` and `take` methods can be run locally (without any Spark
+ * executors).
*
* @group basic
* @since 1.6.0
@@ -301,12 +292,11 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
def isEmpty: Boolean
/**
- * Returns true if this Dataset contains one or more sources that continuously
- * return data as it arrives. A Dataset that reads data from a streaming source
- * must be executed as a `StreamingQuery` using the `start()` method in
- * `DataStreamWriter`. Methods that return a single answer, e.g. `count()` or
- * `collect()`, will throw an [[org.apache.spark.sql.AnalysisException]] when there is a
- * streaming source present.
+ * Returns true if this Dataset contains one or more sources that continuously return data as it
+ * arrives. A Dataset that reads data from a streaming source must be executed as a
+ * `StreamingQuery` using the `start()` method in `DataStreamWriter`. Methods that return a
+ * single answer, e.g. `count()` or `collect()`, will throw an
+ * [[org.apache.spark.sql.AnalysisException]] when there is a streaming source present.
*
* @group streaming
* @since 2.0.0
@@ -314,103 +304,106 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
def isStreaming: Boolean
/**
- * Eagerly checkpoint a Dataset and return the new Dataset. Checkpointing can be used to truncate
- * the logical plan of this Dataset, which is especially useful in iterative algorithms where the
- * plan may grow exponentially. It will be saved to files inside the checkpoint
+ * Eagerly checkpoint a Dataset and return the new Dataset. Checkpointing can be used to
+ * truncate the logical plan of this Dataset, which is especially useful in iterative algorithms
+ * where the plan may grow exponentially. It will be saved to files inside the checkpoint
* directory set with `SparkContext#setCheckpointDir`.
*
* @group basic
* @since 2.1.0
*/
- def checkpoint(): DS[T] = checkpoint(eager = true, reliableCheckpoint = true)
+ def checkpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = true)
/**
* Returns a checkpointed version of this Dataset. Checkpointing can be used to truncate the
* logical plan of this Dataset, which is especially useful in iterative algorithms where the
- * plan may grow exponentially. It will be saved to files inside the checkpoint
- * directory set with `SparkContext#setCheckpointDir`.
- *
- * @param eager Whether to checkpoint this dataframe immediately
- * @note When checkpoint is used with eager = false, the final data that is checkpointed after
- * the first action may be different from the data that was used during the job due to
- * non-determinism of the underlying operation and retries. If checkpoint is used to achieve
- * saving a deterministic snapshot of the data, eager = true should be used. Otherwise,
- * it is only deterministic after the first execution, after the checkpoint was finalized.
+ * plan may grow exponentially. It will be saved to files inside the checkpoint directory set
+ * with `SparkContext#setCheckpointDir`.
+ *
+ * @param eager
+ * Whether to checkpoint this dataframe immediately
+ * @note
+ * When checkpoint is used with eager = false, the final data that is checkpointed after the
+ * first action may be different from the data that was used during the job due to
+ * non-determinism of the underlying operation and retries. If checkpoint is used to achieve
+ * saving a deterministic snapshot of the data, eager = true should be used. Otherwise, it is
+ * only deterministic after the first execution, after the checkpoint was finalized.
* @group basic
* @since 2.1.0
*/
- def checkpoint(eager: Boolean): DS[T] = checkpoint(eager = eager, reliableCheckpoint = true)
+ def checkpoint(eager: Boolean): Dataset[T] =
+ checkpoint(eager = eager, reliableCheckpoint = true)
/**
- * Eagerly locally checkpoints a Dataset and return the new Dataset. Checkpointing can be
- * used to truncate the logical plan of this Dataset, which is especially useful in iterative
+ * Eagerly locally checkpoints a Dataset and return the new Dataset. Checkpointing can be used
+ * to truncate the logical plan of this Dataset, which is especially useful in iterative
* algorithms where the plan may grow exponentially. Local checkpoints are written to executor
* storage and despite potentially faster they are unreliable and may compromise job completion.
*
* @group basic
* @since 2.3.0
*/
- def localCheckpoint(): DS[T] = checkpoint(eager = true, reliableCheckpoint = false)
+ def localCheckpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = false)
/**
- * Locally checkpoints a Dataset and return the new Dataset. Checkpointing can be used to truncate
- * the logical plan of this Dataset, which is especially useful in iterative algorithms where the
- * plan may grow exponentially. Local checkpoints are written to executor storage and despite
- * potentially faster they are unreliable and may compromise job completion.
+ * Locally checkpoints a Dataset and return the new Dataset. Checkpointing can be used to
+ * truncate the logical plan of this Dataset, which is especially useful in iterative algorithms
+ * where the plan may grow exponentially. Local checkpoints are written to executor storage and
+ * despite potentially faster they are unreliable and may compromise job completion.
*
- * @param eager Whether to checkpoint this dataframe immediately
- * @note When checkpoint is used with eager = false, the final data that is checkpointed after
- * the first action may be different from the data that was used during the job due to
- * non-determinism of the underlying operation and retries. If checkpoint is used to achieve
- * saving a deterministic snapshot of the data, eager = true should be used. Otherwise,
- * it is only deterministic after the first execution, after the checkpoint was finalized.
+ * @param eager
+ * Whether to checkpoint this dataframe immediately
+ * @note
+ * When checkpoint is used with eager = false, the final data that is checkpointed after the
+ * first action may be different from the data that was used during the job due to
+ * non-determinism of the underlying operation and retries. If checkpoint is used to achieve
+ * saving a deterministic snapshot of the data, eager = true should be used. Otherwise, it is
+ * only deterministic after the first execution, after the checkpoint was finalized.
* @group basic
* @since 2.3.0
*/
- def localCheckpoint(eager: Boolean): DS[T] = checkpoint(
- eager = eager,
- reliableCheckpoint = false
- )
+ def localCheckpoint(eager: Boolean): Dataset[T] =
+ checkpoint(eager = eager, reliableCheckpoint = false)
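// Sketch of the checkpoint variants (assumes `spark` and a DataFrame `df` from an iterative job):
spark.sparkContext.setCheckpointDir("/tmp/checkpoints") // needed for reliable checkpoints
val reliable = df.checkpoint()        // eager, written to the checkpoint directory
val local = df.localCheckpoint(false) // lazy, executor storage; per the note above, only
                                      // deterministic after the first action finalizes it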
/**
* Returns a checkpointed version of this Dataset.
*
- * @param eager Whether to checkpoint this dataframe immediately
- * @param reliableCheckpoint Whether to create a reliable checkpoint saved to files inside the
- * checkpoint directory. If false creates a local checkpoint using
- * the caching subsystem
+ * @param eager
+ * Whether to checkpoint this dataframe immediately
+ * @param reliableCheckpoint
+ * Whether to create a reliable checkpoint saved to files inside the checkpoint directory. If
+ * false creates a local checkpoint using the caching subsystem
*/
- protected def checkpoint(eager: Boolean, reliableCheckpoint: Boolean): DS[T]
+ protected def checkpoint(eager: Boolean, reliableCheckpoint: Boolean): Dataset[T]
/**
* Defines an event time watermark for this [[Dataset]]. A watermark tracks a point in time
* before which we assume no more late data is going to arrive.
*
- * Spark will use this watermark for several purposes:
- * <ul>
- * <li>To know when a given time window aggregation can be finalized and thus can be emitted
- * when using output modes that do not allow updates.</li>
- * <li>To minimize the amount of state that we need to keep for on-going aggregations,
- * `mapGroupsWithState` and `dropDuplicates` operators.</li>
- * </ul>
- *
- * The current watermark is computed by looking at the `MAX(eventTime)` seen across
- * all of the partitions in the query minus a user specified `delayThreshold`. Due to the cost
- * of coordinating this value across partitions, the actual watermark used is only guaranteed
- * to be at least `delayThreshold` behind the actual event time. In some cases we may still
- * process records that arrive more than `delayThreshold` late.
- *
- * @param eventTime the name of the column that contains the event time of the row.
- * @param delayThreshold the minimum delay to wait to data to arrive late, relative to the latest
- * record that has been processed in the form of an interval
- * (e.g. "1 minute" or "5 hours"). NOTE: This should not be negative.
+ * Spark will use this watermark for several purposes: <ul> <li>To know when a given time window
+ * aggregation can be finalized and thus can be emitted when using output modes that do not
+ * allow updates.</li> <li>To minimize the amount of state that we need to keep for on-going
+ * aggregations, `mapGroupsWithState` and `dropDuplicates` operators.</li> </ul> The current
+ * watermark is computed by looking at the `MAX(eventTime)` seen across all of the partitions in
+ * the query minus a user specified `delayThreshold`. Due to the cost of coordinating this value
+ * across partitions, the actual watermark used is only guaranteed to be at least
+ * `delayThreshold` behind the actual event time. In some cases we may still process records
+ * that arrive more than `delayThreshold` late.
+ *
+ * @param eventTime
+ * the name of the column that contains the event time of the row.
+ * @param delayThreshold
+ * the minimum delay to wait for data to arrive late, relative to the latest record that has
+ * been processed in the form of an interval (e.g. "1 minute" or "5 hours"). NOTE: This should
+ * not be negative.
* @group streaming
* @since 2.1.0
*/
// We only accept an existing column name, not a derived column here as a watermark that is
// defined on a derived column cannot referenced elsewhere in the plan.
- def withWatermark(eventTime: String, delayThreshold: String): DS[T]
+ def withWatermark(eventTime: String, delayThreshold: String): Dataset[T]
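// Streaming sketch (assumes a streaming DataFrame `events` with a timestamp column `eventTime`):
import org.apache.spark.sql.functions.{col, window}
val counts = events
  .withWatermark("eventTime", "10 minutes")       // tolerate up to ~10 minutes of late data
  .groupBy(window(col("eventTime"), "5 minutes"))
  .count()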
- /**
+ /**
* Displays the Dataset in a tabular form. Strings more than 20 characters will be truncated,
* and all cells will be aligned right. For example:
* {{{
@@ -422,7 +415,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* 1984 04 0.450090 0.483521
* }}}
*
- * @param numRows Number of rows to show
+ * @param numRows
+ * Number of rows to show
*
* @group action
* @since 1.6.0
@@ -430,8 +424,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
def show(numRows: Int): Unit = show(numRows, truncate = true)
/**
- * Displays the top 20 rows of Dataset in a tabular form. Strings more than 20 characters
- * will be truncated, and all cells will be aligned right.
+ * Displays the top 20 rows of Dataset in a tabular form. Strings more than 20 characters will
+ * be truncated, and all cells will be aligned right.
*
* @group action
* @since 1.6.0
@@ -441,8 +435,9 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
/**
* Displays the top 20 rows of Dataset in a tabular form.
*
- * @param truncate Whether truncate long strings. If true, strings more than 20 characters will
- * be truncated and all cells will be aligned right
+ * @param truncate
+ * Whether to truncate long strings. If true, strings more than 20 characters will be truncated
+ * and all cells will be aligned right
*
* @group action
* @since 1.6.0
@@ -459,9 +454,11 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* 1983 03 0.410516 0.442194
* 1984 04 0.450090 0.483521
* }}}
- * @param numRows Number of rows to show
- * @param truncate Whether truncate long strings. If true, strings more than 20 characters will
- * be truncated and all cells will be aligned right
+ * @param numRows
+ * Number of rows to show
+ * @param truncate
+ * Whether to truncate long strings. If true, strings more than 20 characters will be truncated
+ * and all cells will be aligned right
*
* @group action
* @since 1.6.0
@@ -480,9 +477,11 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* 1984 04 0.450090 0.483521
* }}}
*
- * @param numRows Number of rows to show
- * @param truncate If set to more than 0, truncates strings to `truncate` characters and
- * all cells will be aligned right.
+ * @param numRows
+ * Number of rows to show
+ * @param truncate
+ * If set to more than 0, truncates strings to `truncate` characters and all cells will be
+ * aligned right.
* @group action
* @since 1.6.0
*/
@@ -499,7 +498,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* 1984 04 0.450090 0.483521
* }}}
*
- * If `vertical` enabled, this command prints output rows vertically (one line per column value)?
+ * If `vertical` is enabled, this command prints output rows vertically (one line per column
+ * value).
*
* {{{
* -RECORD 0-------------------
@@ -529,10 +529,13 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* AVG('Adj Close) | 0.483521
* }}}
*
- * @param numRows Number of rows to show
- * @param truncate If set to more than 0, truncates strings to `truncate` characters and
- * all cells will be aligned right.
- * @param vertical If set to true, prints output rows vertically (one line per column value).
+ * @param numRows
+ * Number of rows to show
+ * @param truncate
+ * If set to more than 0, truncates strings to `truncate` characters and all cells will be
+ * aligned right.
+ * @param vertical
+ * If set to true, prints output rows vertically (one line per column value).
* @group action
* @since 2.3.0
*/
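// Illustrative calls for the show overloads documented above (`df` is assumed):
df.show(5)          // 5 rows, cells truncated to 20 characters
df.show(5, 30)      // 5 rows, cells truncated to 30 characters
df.show(5, 0, true) // 5 rows, no truncation, printed vertically (one line per column value)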
@@ -549,7 +552,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @group untypedrel
* @since 1.6.0
*/
- def na: DataFrameNaFunctions[DS]
+ def na: DataFrameNaFunctions
/**
* Returns a [[DataFrameStatFunctions]] for working statistic functions support.
@@ -561,20 +564,21 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @group untypedrel
* @since 1.6.0
*/
- def stat: DataFrameStatFunctions[DS]
+ def stat: DataFrameStatFunctions
/**
* Join with another `DataFrame`.
*
* Behaves as an INNER JOIN and requires a subsequent join predicate.
*
- * @param right Right side of the join operation.
+ * @param right
+ * Right side of the join operation.
* @group untypedrel
* @since 2.0.0
*/
- def join(right: DS[_]): DS[Row]
+ def join(right: DS[_]): Dataset[Row]
- /**
+ /**
* Inner equi-join with another `DataFrame` using the given column.
*
* Different from other join functions, the join column will only appear once in the output,
@@ -585,17 +589,20 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* df1.join(df2, "user_id")
* }}}
*
- * @param right Right side of the join operation.
- * @param usingColumn Name of the column to join on. This column must exist on both sides.
+ * @param right
+ * Right side of the join operation.
+ * @param usingColumn
+ * Name of the column to join on. This column must exist on both sides.
*
- * @note If you perform a self-join using this function without aliasing the input
- * `DataFrame`s, you will NOT be able to reference any columns after the join, since
- * there is no way to disambiguate which side of the join you would like to reference.
+ * @note
+ * If you perform a self-join using this function without aliasing the input `DataFrame`s, you
+ * will NOT be able to reference any columns after the join, since there is no way to
+ * disambiguate which side of the join you would like to reference.
*
* @group untypedrel
* @since 2.0.0
*/
- def join(right: DS[_], usingColumn: String): DS[Row] = {
+ def join(right: DS[_], usingColumn: String): Dataset[Row] = {
join(right, Seq(usingColumn))
}
@@ -603,13 +610,15 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* (Java-specific) Inner equi-join with another `DataFrame` using the given columns. See the
* Scala-specific overload for more details.
*
- * @param right Right side of the join operation.
- * @param usingColumns Names of the columns to join on. This columns must exist on both sides.
+ * @param right
+ * Right side of the join operation.
+ * @param usingColumns
+ * Names of the columns to join on. These columns must exist on both sides.
*
* @group untypedrel
* @since 3.4.0
*/
- def join(right: DS[_], usingColumns: Array[String]): DS[Row] = {
+ def join(right: DS[_], usingColumns: Array[String]): Dataset[Row] = {
join(right, usingColumns.toImmutableArraySeq)
}
@@ -624,43 +633,50 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* df1.join(df2, Seq("user_id", "user_name"))
* }}}
*
- * @param right Right side of the join operation.
- * @param usingColumns Names of the columns to join on. This columns must exist on both sides.
+ * @param right
+ * Right side of the join operation.
+ * @param usingColumns
+ * Names of the columns to join on. These columns must exist on both sides.
*
- * @note If you perform a self-join using this function without aliasing the input
- * `DataFrame`s, you will NOT be able to reference any columns after the join, since
- * there is no way to disambiguate which side of the join you would like to reference.
+ * @note
+ * If you perform a self-join using this function without aliasing the input `DataFrame`s, you
+ * will NOT be able to reference any columns after the join, since there is no way to
+ * disambiguate which side of the join you would like to reference.
*
* @group untypedrel
* @since 2.0.0
*/
- def join(right: DS[_], usingColumns: Seq[String]): DS[Row] = {
+ def join(right: DS[_], usingColumns: Seq[String]): Dataset[Row] = {
join(right, usingColumns, "inner")
}
/**
- * Equi-join with another `DataFrame` using the given column. A cross join with a predicate
- * is specified as an inner join. If you would explicitly like to perform a cross join use the
+ * Equi-join with another `DataFrame` using the given column. A cross join with a predicate is
+ * specified as an inner join. If you would explicitly like to perform a cross join use the
* `crossJoin` method.
*
* Different from other join functions, the join column will only appear once in the output,
* i.e. similar to SQL's `JOIN USING` syntax.
*
- * @param right Right side of the join operation.
- * @param usingColumn Name of the column to join on. This column must exist on both sides.
- * @param joinType Type of join to perform. Default `inner`. Must be one of:
- * `inner`, `cross`, `outer`, `full`, `fullouter`, `full_outer`, `left`,
- * `leftouter`, `left_outer`, `right`, `rightouter`, `right_outer`,
- * `semi`, `leftsemi`, `left_semi`, `anti`, `leftanti`, `left_anti`.
- *
- * @note If you perform a self-join using this function without aliasing the input
- * `DataFrame`s, you will NOT be able to reference any columns after the join, since
- * there is no way to disambiguate which side of the join you would like to reference.
+ * @param right
+ * Right side of the join operation.
+ * @param usingColumn
+ * Name of the column to join on. This column must exist on both sides.
+ * @param joinType
+ * Type of join to perform. Default `inner`. Must be one of: `inner`, `cross`, `outer`,
+ * `full`, `fullouter`, `full_outer`, `left`, `leftouter`, `left_outer`, `right`,
+ * `rightouter`, `right_outer`, `semi`, `leftsemi`, `left_semi`, `anti`, `leftanti`,
+ * `left_anti`.
+ *
+ * @note
+ * If you perform a self-join using this function without aliasing the input `DataFrame`s, you
+ * will NOT be able to reference any columns after the join, since there is no way to
+ * disambiguate which side of the join you would like to reference.
*
* @group untypedrel
* @since 3.4.0
*/
- def join(right: DS[_], usingColumn: String, joinType: String): DS[Row] = {
+ def join(right: DS[_], usingColumn: String, joinType: String): Dataset[Row] = {
join(right, Seq(usingColumn), joinType)
}
@@ -668,17 +684,20 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* (Java-specific) Equi-join with another `DataFrame` using the given columns. See the
* Scala-specific overload for more details.
*
- * @param right Right side of the join operation.
- * @param usingColumns Names of the columns to join on. This columns must exist on both sides.
- * @param joinType Type of join to perform. Default `inner`. Must be one of:
- * `inner`, `cross`, `outer`, `full`, `fullouter`, `full_outer`, `left`,
- * `leftouter`, `left_outer`, `right`, `rightouter`, `right_outer`,
- * `semi`, `leftsemi`, `left_semi`, `anti`, `leftanti`, `left_anti`.
+ * @param right
+ * Right side of the join operation.
+ * @param usingColumns
+ * Names of the columns to join on. These columns must exist on both sides.
+ * @param joinType
+ * Type of join to perform. Default `inner`. Must be one of: `inner`, `cross`, `outer`,
+ * `full`, `fullouter`, `full_outer`, `left`, `leftouter`, `left_outer`, `right`,
+ * `rightouter`, `right_outer`, `semi`, `leftsemi`, `left_semi`, `anti`, `leftanti`,
+ * `left_anti`.
*
* @group untypedrel
* @since 3.4.0
*/
- def join(right: DS[_], usingColumns: Array[String], joinType: String): DS[Row] = {
+ def join(right: DS[_], usingColumns: Array[String], joinType: String): Dataset[Row] = {
join(right, usingColumns.toImmutableArraySeq, joinType)
}
@@ -690,21 +709,25 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* Different from other join functions, the join columns will only appear once in the output,
* i.e. similar to SQL's `JOIN USING` syntax.
*
- * @param right Right side of the join operation.
- * @param usingColumns Names of the columns to join on. This columns must exist on both sides.
- * @param joinType Type of join to perform. Default `inner`. Must be one of:
- * `inner`, `cross`, `outer`, `full`, `fullouter`, `full_outer`, `left`,
- * `leftouter`, `left_outer`, `right`, `rightouter`, `right_outer`,
- * `semi`, `leftsemi`, `left_semi`, `anti`, `leftanti`, `left_anti`.
- *
- * @note If you perform a self-join using this function without aliasing the input
- * `DataFrame`s, you will NOT be able to reference any columns after the join, since
- * there is no way to disambiguate which side of the join you would like to reference.
+ * @param right
+ * Right side of the join operation.
+ * @param usingColumns
+ * Names of the columns to join on. These columns must exist on both sides.
+ * @param joinType
+ * Type of join to perform. Default `inner`. Must be one of: `inner`, `cross`, `outer`,
+ * `full`, `fullouter`, `full_outer`, `left`, `leftouter`, `left_outer`, `right`,
+ * `rightouter`, `right_outer`, `semi`, `leftsemi`, `left_semi`, `anti`, `leftanti`,
+ * `left_anti`.
+ *
+ * @note
+ * If you perform a self-join using this function without aliasing the input `DataFrame`s, you
+ * will NOT be able to reference any columns after the join, since there is no way to
+ * disambiguate which side of the join you would like to reference.
*
* @group untypedrel
* @since 2.0.0
*/
- def join(right: DS[_], usingColumns: Seq[String], joinType: String): DS[Row]
+ def join(right: DS[_], usingColumns: Seq[String], joinType: String): Dataset[Row]
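// Equi-join sketch using a column list and an explicit join type (`df1`, `df2`, `people` assumed):
import org.apache.spark.sql.functions.col
val joined = df1.join(df2, Seq("user_id", "user_name"), "left_outer")
// For self-joins, alias both sides first so columns stay referenceable, as the note above warns:
val tree = people.as("p").join(people.as("c"), col("c.parent_id") === col("p.id"), "left_outer")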
/**
* Inner join with another `DataFrame`, using the given join expression.
@@ -718,12 +741,12 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @group untypedrel
* @since 2.0.0
*/
- def join(right: DS[_], joinExprs: Column): DS[Row] =
+ def join(right: DS[_], joinExprs: Column): Dataset[Row] =
join(right, joinExprs, "inner")
/**
- * Join with another `DataFrame`, using the given join expression. The following performs
- * a full outer join between `df1` and `df2`.
+ * Join with another `DataFrame`, using the given join expression. The following performs a full
+ * outer join between `df1` and `df2`.
*
* {{{
* // Scala:
@@ -735,64 +758,73 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* df1.join(df2, col("df1Key").equalTo(col("df2Key")), "outer");
* }}}
*
- * @param right Right side of the join.
- * @param joinExprs Join expression.
- * @param joinType Type of join to perform. Default `inner`. Must be one of:
- * `inner`, `cross`, `outer`, `full`, `fullouter`, `full_outer`, `left`,
- * `leftouter`, `left_outer`, `right`, `rightouter`, `right_outer`,
- * `semi`, `leftsemi`, `left_semi`, `anti`, `leftanti`, `left_anti`.
+ * @param right
+ * Right side of the join.
+ * @param joinExprs
+ * Join expression.
+ * @param joinType
+ * Type of join to perform. Default `inner`. Must be one of: `inner`, `cross`, `outer`,
+ * `full`, `fullouter`, `full_outer`, `left`, `leftouter`, `left_outer`, `right`,
+ * `rightouter`, `right_outer`, `semi`, `leftsemi`, `left_semi`, `anti`, `leftanti`,
+ * `left_anti`.
*
* @group untypedrel
* @since 2.0.0
*/
- def join(right: DS[_], joinExprs: Column, joinType: String): DS[Row]
+ def join(right: DS[_], joinExprs: Column, joinType: String): Dataset[Row]
/**
* Explicit cartesian join with another `DataFrame`.
*
- * @param right Right side of the join operation.
- * @note Cartesian joins are very expensive without an extra filter that can be pushed down.
+ * @param right
+ * Right side of the join operation.
+ * @note
+ * Cartesian joins are very expensive without an extra filter that can be pushed down.
* @group untypedrel
* @since 2.1.0
*/
- def crossJoin(right: DS[_]): DS[Row]
+ def crossJoin(right: DS[_]): Dataset[Row]
/**
- * Joins this Dataset returning a `Tuple2` for each pair where `condition` evaluates to
- * true.
+ * Joins this Dataset returning a `Tuple2` for each pair where `condition` evaluates to true.
*
- * This is similar to the relation `join` function with one important difference in the
- * result schema. Since `joinWith` preserves objects present on either side of the join, the
- * result schema is similarly nested into a tuple under the column names `_1` and `_2`.
+ * This is similar to the relation `join` function with one important difference in the result
+ * schema. Since `joinWith` preserves objects present on either side of the join, the result
+ * schema is similarly nested into a tuple under the column names `_1` and `_2`.
*
* This type of join can be useful both for preserving type-safety with the original object
- * types as well as working with relational data where either side of the join has column
- * names in common.
- *
- * @param other Right side of the join.
- * @param condition Join expression.
- * @param joinType Type of join to perform. Default `inner`. Must be one of:
- * `inner`, `cross`, `outer`, `full`, `fullouter`,`full_outer`, `left`,
- * `leftouter`, `left_outer`, `right`, `rightouter`, `right_outer`.
+ * types as well as working with relational data where either side of the join has column names
+ * in common.
+ *
+ * @param other
+ * Right side of the join.
+ * @param condition
+ * Join expression.
+ * @param joinType
+ * Type of join to perform. Default `inner`. Must be one of: `inner`, `cross`, `outer`,
+ * `full`, `fullouter`,`full_outer`, `left`, `leftouter`, `left_outer`, `right`, `rightouter`,
+ * `right_outer`.
* @group typedrel
* @since 1.6.0
*/
- def joinWith[U](other: DS[U], condition: Column, joinType: String): DS[(T, U)]
+ def joinWith[U](other: DS[U], condition: Column, joinType: String): Dataset[(T, U)]
/**
- * Using inner equi-join to join this Dataset returning a `Tuple2` for each pair
- * where `condition` evaluates to true.
+ * Using inner equi-join to join this Dataset returning a `Tuple2` for each pair where
+ * `condition` evaluates to true.
*
- * @param other Right side of the join.
- * @param condition Join expression.
+ * @param other
+ * Right side of the join.
+ * @param condition
+ * Join expression.
* @group typedrel
* @since 1.6.0
*/
- def joinWith[U](other: DS[U], condition: Column): DS[(T, U)] = {
+ def joinWith[U](other: DS[U], condition: Column): Dataset[(T, U)] = {
joinWith(other, condition, "inner")
}
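// Hedged sketch of `joinWith`: both sides survive as typed objects under `_1`/`_2`.
// `persons: Dataset[Person]` and `orders: Dataset[Order]` are assumed inputs.
val pairs = persons.joinWith(orders, persons("id") === orders("personId"), "left_outer")
// pairs: Dataset[(Person, Order)]; with an outer join the missing side comes back as null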
- protected def sortInternal(global: Boolean, sortExprs: Seq[Column]): DS[T]
+ protected def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T]
/**
* Returns a new Dataset with each partition sorted by the given expressions.
@@ -803,8 +835,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @since 2.0.0
*/
@scala.annotation.varargs
- def sortWithinPartitions(sortCol: String, sortCols: String*): DS[T] = {
- sortWithinPartitions((sortCol +: sortCols).map(Column(_)) : _*)
+ def sortWithinPartitions(sortCol: String, sortCols: String*): Dataset[T] = {
+ sortWithinPartitions((sortCol +: sortCols).map(Column(_)): _*)
}
/**
@@ -816,7 +848,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @since 2.0.0
*/
@scala.annotation.varargs
- def sortWithinPartitions(sortExprs: Column*): DS[T] = {
+ def sortWithinPartitions(sortExprs: Column*): Dataset[T] = {
sortInternal(global = false, sortExprs)
}
@@ -833,8 +865,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @since 2.0.0
*/
@scala.annotation.varargs
- def sort(sortCol: String, sortCols: String*): DS[T] = {
- sort((sortCol +: sortCols).map(Column(_)) : _*)
+ def sort(sortCol: String, sortCols: String*): Dataset[T] = {
+ sort((sortCol +: sortCols).map(Column(_)): _*)
}
/**
@@ -847,33 +879,33 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @since 2.0.0
*/
@scala.annotation.varargs
- def sort(sortExprs: Column*): DS[T] = {
+ def sort(sortExprs: Column*): Dataset[T] = {
sortInternal(global = true, sortExprs)
}
/**
- * Returns a new Dataset sorted by the given expressions.
- * This is an alias of the `sort` function.
+ * Returns a new Dataset sorted by the given expressions. This is an alias of the `sort`
+ * function.
*
* @group typedrel
* @since 2.0.0
*/
@scala.annotation.varargs
- def orderBy(sortCol: String, sortCols: String*): DS[T] = sort(sortCol, sortCols : _*)
+ def orderBy(sortCol: String, sortCols: String*): Dataset[T] = sort(sortCol, sortCols: _*)
/**
- * Returns a new Dataset sorted by the given expressions.
- * This is an alias of the `sort` function.
+ * Returns a new Dataset sorted by the given expressions. This is an alias of the `sort`
+ * function.
*
* @group typedrel
* @since 2.0.0
*/
@scala.annotation.varargs
- def orderBy(sortExprs: Column*): DS[T] = sort(sortExprs : _*)
+ def orderBy(sortExprs: Column*): Dataset[T] = sort(sortExprs: _*)
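// Sorting sketch (`df` assumed): `orderBy` is an alias of `sort`, which orders the whole Dataset,
// while `sortWithinPartitions` only orders rows inside each partition.
import org.apache.spark.sql.functions.{col, desc}
val globallySorted = df.orderBy(desc("age"), col("name"))
val perPartition = df.sortWithinPartitions("age")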
/**
- * Specifies some hint on the current Dataset. As an example, the following code specifies
- * that one of the plan can be broadcasted:
+ * Specifies some hint on the current Dataset. As an example, the following code specifies that
+ * one of the plans can be broadcasted:
*
* {{{
* df1.join(df2.hint("broadcast"))
@@ -886,19 +918,22 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* df1.hint("rebalance", 10)
* }}}
*
- * @param name the name of the hint
- * @param parameters the parameters of the hint, all the parameters should be a `Column` or
- * `Expression` or `Symbol` or could be converted into a `Literal`
+ * @param name
+ * the name of the hint
+ * @param parameters
+ * the parameters of the hint, all the parameters should be a `Column` or `Expression` or
+ * `Symbol` or could be converted into a `Literal`
* @group basic
* @since 2.2.0
*/
@scala.annotation.varargs
- def hint(name: String, parameters: Any*): DS[T]
+ def hint(name: String, parameters: Any*): Dataset[T]
/**
* Selects column based on the column name and returns it as a [[org.apache.spark.sql.Column]].
*
- * @note The column name can also reference to a nested column like `a.b`.
+ * @note
+ * The column name can also reference to a nested column like `a.b`.
* @group untypedrel
* @since 2.0.0
*/
@@ -907,7 +942,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
/**
* Selects column based on the column name and returns it as a [[org.apache.spark.sql.Column]].
*
- * @note The column name can also reference to a nested column like `a.b`.
+ * @note
+ * The column name can also reference to a nested column like `a.b`.
* @group untypedrel
* @since 2.0.0
*/
@@ -940,7 +976,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @group typedrel
* @since 1.6.0
*/
- def as(alias: String): DS[T]
+ def as(alias: String): Dataset[T]
/**
* (Scala-specific) Returns a new Dataset with an alias set.
@@ -948,7 +984,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @group typedrel
* @since 2.0.0
*/
- def as(alias: Symbol): DS[T] = as(alias.name)
+ def as(alias: Symbol): Dataset[T] = as(alias.name)
/**
* Returns a new Dataset with an alias set. Same as `as`.
@@ -956,7 +992,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @group typedrel
* @since 2.0.0
*/
- def alias(alias: String): DS[T] = as(alias)
+ def alias(alias: String): Dataset[T] = as(alias)
/**
* (Scala-specific) Returns a new Dataset with an alias set. Same as `as`.
@@ -964,7 +1000,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @group typedrel
* @since 2.0.0
*/
- def alias(alias: Symbol): DS[T] = as(alias)
+ def alias(alias: Symbol): Dataset[T] = as(alias)
/**
* Selects a set of column based expressions.
@@ -976,11 +1012,11 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @since 2.0.0
*/
@scala.annotation.varargs
- def select(cols: Column*): DS[Row]
+ def select(cols: Column*): Dataset[Row]
/**
- * Selects a set of columns. This is a variant of `select` that can only select
- * existing columns using column names (i.e. cannot construct expressions).
+ * Selects a set of columns. This is a variant of `select` that can only select existing columns
+ * using column names (i.e. cannot construct expressions).
*
* {{{
* // The following two are equivalent:
@@ -992,11 +1028,10 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @since 2.0.0
*/
@scala.annotation.varargs
- def select(col: String, cols: String*): DS[Row] = select((col +: cols).map(Column(_)): _*)
+ def select(col: String, cols: String*): Dataset[Row] = select((col +: cols).map(Column(_)): _*)
/**
- * Selects a set of SQL expressions. This is a variant of `select` that accepts
- * SQL expressions.
+ * Selects a set of SQL expressions. This is a variant of `select` that accepts SQL expressions.
*
* {{{
* // The following are equivalent:
@@ -1008,7 +1043,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @since 2.0.0
*/
@scala.annotation.varargs
- def selectExpr(exprs: String*): DS[Row] = select(exprs.map(functions.expr): _*)
+ def selectExpr(exprs: String*): Dataset[Row] = select(exprs.map(functions.expr): _*)
/**
* Returns a new Dataset by computing the given [[org.apache.spark.sql.Column]] expression for
@@ -1022,14 +1057,14 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @group typedrel
* @since 1.6.0
*/
- def select[U1](c1: TypedColumn[T, U1]): DS[U1]
+ def select[U1](c1: TypedColumn[T, U1]): Dataset[U1]
/**
* Internal helper function for building typed selects that return tuples. For simplicity and
- * code reuse, we do this without the help of the type system and then use helper functions
- * that cast appropriately for the user facing interface.
+ * code reuse, we do this without the help of the type system and then use helper functions that
+ * cast appropriately for the user facing interface.
*/
- protected def selectUntyped(columns: TypedColumn[_, _]*): DS[_]
+ protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_]
/**
* Returns a new Dataset by computing the given [[org.apache.spark.sql.Column]] expressions for
@@ -1038,8 +1073,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @group typedrel
* @since 1.6.0
*/
- def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): DS[(U1, U2)] =
- selectUntyped(c1, c2).asInstanceOf[DS[(U1, U2)]]
+ def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): Dataset[(U1, U2)] =
+ selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]]
/**
* Returns a new Dataset by computing the given [[org.apache.spark.sql.Column]] expressions for
@@ -1051,8 +1086,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
def select[U1, U2, U3](
c1: TypedColumn[T, U1],
c2: TypedColumn[T, U2],
- c3: TypedColumn[T, U3]): DS[(U1, U2, U3)] =
- selectUntyped(c1, c2, c3).asInstanceOf[DS[(U1, U2, U3)]]
+ c3: TypedColumn[T, U3]): Dataset[(U1, U2, U3)] =
+ selectUntyped(c1, c2, c3).asInstanceOf[Dataset[(U1, U2, U3)]]
/**
* Returns a new Dataset by computing the given [[org.apache.spark.sql.Column]] expressions for
@@ -1065,8 +1100,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
c1: TypedColumn[T, U1],
c2: TypedColumn[T, U2],
c3: TypedColumn[T, U3],
- c4: TypedColumn[T, U4]): DS[(U1, U2, U3, U4)] =
- selectUntyped(c1, c2, c3, c4).asInstanceOf[DS[(U1, U2, U3, U4)]]
+ c4: TypedColumn[T, U4]): Dataset[(U1, U2, U3, U4)] =
+ selectUntyped(c1, c2, c3, c4).asInstanceOf[Dataset[(U1, U2, U3, U4)]]
/**
* Returns a new Dataset by computing the given [[org.apache.spark.sql.Column]] expressions for
@@ -1080,8 +1115,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
c2: TypedColumn[T, U2],
c3: TypedColumn[T, U3],
c4: TypedColumn[T, U4],
- c5: TypedColumn[T, U5]): DS[(U1, U2, U3, U4, U5)] =
- selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[DS[(U1, U2, U3, U4, U5)]]
+ c5: TypedColumn[T, U5]): Dataset[(U1, U2, U3, U4, U5)] =
+ selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]]
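// Typed-select sketch mirroring the TypedColumn variants above (assumes a Dataset `ds` with
// columns `name` and `value`; names are illustrative).
import spark.implicits._
import org.apache.spark.sql.functions.expr
val incremented = ds.select(expr("value + 1").as[Int])                    // Dataset[Int]
val pairs = ds.select(expr("name").as[String], expr("value + 1").as[Int]) // Dataset[(String, Int)]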
/**
* Filters rows using the given condition.
@@ -1094,7 +1129,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @group typedrel
* @since 1.6.0
*/
- def filter(condition: Column): DS[T]
+ def filter(condition: Column): Dataset[T]
/**
* Filters rows using the given SQL expression.
@@ -1105,26 +1140,26 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @group typedrel
* @since 1.6.0
*/
- def filter(conditionExpr: String): DS[T] =
+ def filter(conditionExpr: String): Dataset[T] =
filter(functions.expr(conditionExpr))
/**
- * (Scala-specific)
- * Returns a new Dataset that only contains elements where `func` returns `true`.
+ * (Scala-specific) Returns a new Dataset that only contains elements where `func` returns
+ * `true`.
*
* @group typedrel
* @since 1.6.0
*/
- def filter(func: T => Boolean): DS[T]
+ def filter(func: T => Boolean): Dataset[T]
/**
- * (Java-specific)
- * Returns a new Dataset that only contains elements where `func` returns `true`.
+ * (Java-specific) Returns a new Dataset that only contains elements where `func` returns
+ * `true`.
*
* @group typedrel
* @since 1.6.0
*/
- def filter(func: FilterFunction[T]): DS[T]
+ def filter(func: FilterFunction[T]): Dataset[T]
/**
* Filters rows using the given condition. This is an alias for `filter`.
@@ -1137,7 +1172,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @group typedrel
* @since 1.6.0
*/
- def where(condition: Column): DS[T] = filter(condition)
+ def where(condition: Column): Dataset[T] = filter(condition)
/**
* Filters rows using the given SQL expression.
@@ -1148,7 +1183,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @group typedrel
* @since 1.6.0
*/
- def where(conditionExpr: String): DS[T] = filter(conditionExpr)
+ def where(conditionExpr: String): Dataset[T] = filter(conditionExpr)
/**
* Groups the Dataset using the specified columns, so we can run aggregation on them. See
@@ -1169,14 +1204,14 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @since 2.0.0
*/
@scala.annotation.varargs
- def groupBy(cols: Column*): RGD
+ def groupBy(cols: Column*): RelationalGroupedDataset
/**
- * Groups the Dataset using the specified columns, so that we can run aggregation on them.
- * See [[RelationalGroupedDataset]] for all the available aggregate functions.
+ * Groups the Dataset using the specified columns, so that we can run aggregation on them. See
+ * [[RelationalGroupedDataset]] for all the available aggregate functions.
*
- * This is a variant of groupBy that can only group by existing columns using column names
- * (i.e. cannot construct expressions).
+ * This is a variant of groupBy that can only group by existing columns using column names (i.e.
+ * cannot construct expressions).
*
* {{{
* // Compute the average for all numeric columns grouped by department.
@@ -1193,12 +1228,13 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @since 2.0.0
*/
@scala.annotation.varargs
- def groupBy(col1: String, cols: String*): RGD = groupBy((col1 +: cols).map(col): _*)
+ def groupBy(col1: String, cols: String*): RelationalGroupedDataset = groupBy(
+ (col1 +: cols).map(col): _*)
/**
- * Create a multi-dimensional rollup for the current Dataset using the specified columns,
- * so we can run aggregation on them.
- * See [[RelationalGroupedDataset]] for all the available aggregate functions.
+ * Create a multi-dimensional rollup for the current Dataset using the specified columns, so we
+ * can run aggregation on them. See [[RelationalGroupedDataset]] for all the available aggregate
+ * functions.
*
* {{{
* // Compute the average for all numeric columns rolled up by department and group.
@@ -1215,15 +1251,15 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @since 2.0.0
*/
@scala.annotation.varargs
- def rollup(cols: Column*): RGD
+ def rollup(cols: Column*): RelationalGroupedDataset
/**
- * Create a multi-dimensional rollup for the current Dataset using the specified columns,
- * so we can run aggregation on them.
- * See [[RelationalGroupedDataset]] for all the available aggregate functions.
+ * Create a multi-dimensional rollup for the current Dataset using the specified columns, so we
+ * can run aggregation on them. See [[RelationalGroupedDataset]] for all the available aggregate
+ * functions.
*
- * This is a variant of rollup that can only group by existing columns using column names
- * (i.e. cannot construct expressions).
+ * This is a variant of rollup that can only group by existing columns using column names (i.e.
+ * cannot construct expressions).
*
* {{{
* // Compute the average for all numeric columns rolled up by department and group.
@@ -1240,12 +1276,13 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @since 2.0.0
*/
@scala.annotation.varargs
- def rollup(col1: String, cols: String*): RGD = rollup((col1 +: cols).map(col): _*)
+ def rollup(col1: String, cols: String*): RelationalGroupedDataset = rollup(
+ (col1 +: cols).map(col): _*)
/**
- * Create a multi-dimensional cube for the current Dataset using the specified columns,
- * so we can run aggregation on them.
- * See [[RelationalGroupedDataset]] for all the available aggregate functions.
+ * Create a multi-dimensional cube for the current Dataset using the specified columns, so we
+ * can run aggregation on them. See [[RelationalGroupedDataset]] for all the available aggregate
+ * functions.
*
* {{{
* // Compute the average for all numeric columns cubed by department and group.
@@ -1262,15 +1299,15 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @since 2.0.0
*/
@scala.annotation.varargs
- def cube(cols: Column*): RGD
+ def cube(cols: Column*): RelationalGroupedDataset
/**
- * Create a multi-dimensional cube for the current Dataset using the specified columns,
- * so we can run aggregation on them.
- * See [[RelationalGroupedDataset]] for all the available aggregate functions.
+ * Create a multi-dimensional cube for the current Dataset using the specified columns, so we
+ * can run aggregation on them. See [[RelationalGroupedDataset]] for all the available aggregate
+ * functions.
*
- * This is a variant of cube that can only group by existing columns using column names
- * (i.e. cannot construct expressions).
+ * This is a variant of cube that can only group by existing columns using column names (i.e.
+ * cannot construct expressions).
*
* {{{
* // Compute the average for all numeric columns cubed by department and group.
@@ -1287,12 +1324,13 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @since 2.0.0
*/
@scala.annotation.varargs
- def cube(col1: String, cols: String*): RGD = cube((col1 +: cols).map(col): _*)
+ def cube(col1: String, cols: String*): RelationalGroupedDataset = cube(
+ (col1 +: cols).map(col): _*)
/**
- * Create multi-dimensional aggregation for the current Dataset using the specified grouping sets,
- * so we can run aggregation on them.
- * See [[RelationalGroupedDataset]] for all the available aggregate functions.
+ * Create multi-dimensional aggregation for the current Dataset using the specified grouping
+ * sets, so we can run aggregation on them. See [[RelationalGroupedDataset]] for all the
+ * available aggregate functions.
*
* {{{
* // Compute the average for all numeric columns group by specific grouping sets.
@@ -1309,7 +1347,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @since 4.0.0
*/
@scala.annotation.varargs
- def groupingSets(groupingSets: Seq[Seq[Column]], cols: Column*): RGD
+ def groupingSets(groupingSets: Seq[Seq[Column]], cols: Column*): RelationalGroupedDataset
/**
* (Scala-specific) Aggregates on the entire Dataset without groups.
@@ -1322,7 +1360,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @group untypedrel
* @since 2.0.0
*/
- def agg(aggExpr: (String, String), aggExprs: (String, String)*): DS[Row] = {
+ def agg(aggExpr: (String, String), aggExprs: (String, String)*): Dataset[Row] = {
groupBy().agg(aggExpr, aggExprs: _*)
}
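
A minimal sketch of the whole-Dataset `agg` variants above, assuming an active `SparkSession` named `spark` with `spark.implicits._` imported (the data and column names are illustrative):

{{{
import org.apache.spark.sql.functions._
import spark.implicits._

val sales = Seq(("a", 10), ("b", 25), ("a", 5)).toDF("key", "amount")

// Tuple form: column name -> aggregate function name.
sales.agg("amount" -> "max", "amount" -> "avg").show()

// Column form: arbitrary aggregate expressions.
sales.agg(sum($"amount").as("total"), countDistinct($"key").as("keys")).show()
}}}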
@@ -1337,7 +1375,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @group untypedrel
* @since 2.0.0
*/
- def agg(exprs: Map[String, String]): DS[Row] = groupBy().agg(exprs)
+ def agg(exprs: Map[String, String]): Dataset[Row] = groupBy().agg(exprs)
/**
* (Java-specific) Aggregates on the entire Dataset without groups.
@@ -1350,7 +1388,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @group untypedrel
* @since 2.0.0
*/
- def agg(exprs: util.Map[String, String]): DS[Row] = groupBy().agg(exprs)
+ def agg(exprs: util.Map[String, String]): Dataset[Row] = groupBy().agg(exprs)
/**
* Aggregates on the entire Dataset without groups.
@@ -1364,12 +1402,11 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @since 2.0.0
*/
@scala.annotation.varargs
- def agg(expr: Column, exprs: Column*): DS[Row] = groupBy().agg(expr, exprs: _*)
+ def agg(expr: Column, exprs: Column*): Dataset[Row] = groupBy().agg(expr, exprs: _*)
/**
- * (Scala-specific)
- * Reduces the elements of this Dataset using the specified binary function. The given `func`
- * must be commutative and associative or the result may be non-deterministic.
+ * (Scala-specific) Reduces the elements of this Dataset using the specified binary function.
+ * The given `func` must be commutative and associative or the result may be non-deterministic.
*
* @group action
* @since 1.6.0
@@ -1377,24 +1414,45 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
def reduce(func: (T, T) => T): T
/**
- * (Java-specific)
- * Reduces the elements of this Dataset using the specified binary function. The given `func`
- * must be commutative and associative or the result may be non-deterministic.
+ * (Java-specific) Reduces the elements of this Dataset using the specified binary function. The
+ * given `func` must be commutative and associative or the result may be non-deterministic.
*
* @group action
* @since 1.6.0
*/
- def reduce(func: ReduceFunction[T]): T = reduce(func.call)
+ def reduce(func: ReduceFunction[T]): T = reduce(ToScalaUDF(func))
+
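
A small sketch of `reduce`, assuming `spark.implicits._` is in scope; as the scaladoc notes, the binary function must be commutative and associative for the result to be deterministic:

{{{
import spark.implicits._

val nums = spark.range(1, 5).as[Long]   // Dataset[Long]: 1, 2, 3, 4

// Commutative and associative, so the result is deterministic: 10
val total = nums.reduce(_ + _)
}}}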
+ /**
+ * (Scala-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given
+ * key `func`.
+ *
+ * @group typedrel
+ * @since 2.0.0
+ */
+ def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T]
+
+ /**
+ * (Java-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given
+ * key `func`.
+ *
+ * @group typedrel
+ * @since 2.0.0
+ */
+ def groupByKey[K](
+ func: MapFunction[T, K],
+ encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = {
+ groupByKey(ToScalaUDF(func))(encoder)
+ }
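
To illustrate the `groupByKey` variants added here, a sketch (Scala API, assuming `spark.implicits._` is imported; the case class is hypothetical) that groups typed records by a derived key and counts them:

{{{
import spark.implicits._

case class Sale(city: String, amount: Double)
val ds = Seq(Sale("NYC", 10.0), Sale("NYC", 5.0), Sale("SF", 7.5)).toDS()

// Returns a KeyValueGroupedDataset[String, Sale]; count() yields (key, count) pairs.
val perCity = ds.groupByKey(_.city).count()
perCity.show()
}}}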
/**
- * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns set.
- * This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
+ * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns
+ * set. This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
* which cannot be reversed.
*
- * This function is useful to massage a DataFrame into a format where some
- * columns are identifier columns ("ids"), while all other columns ("values")
- * are "unpivoted" to the rows, leaving just two non-id columns, named as given
- * by `variableColumnName` and `valueColumnName`.
+ * This function is useful to massage a DataFrame into a format where some columns are
+ * identifier columns ("ids"), while all other columns ("values") are "unpivoted" to the rows,
+ * leaving just two non-id columns, named as given by `variableColumnName` and
+ * `valueColumnName`.
*
* {{{
* val df = Seq((1, 11, 12L), (2, 21, 22L)).toDF("id", "int", "long")
@@ -1424,18 +1482,22 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* // |-- value: long (nullable = true)
* }}}
*
- * When no "id" columns are given, the unpivoted DataFrame consists of only the
- * "variable" and "value" columns.
+ * When no "id" columns are given, the unpivoted DataFrame consists of only the "variable" and
+ * "value" columns.
*
* All "value" columns must share a least common data type. Unless they are the same data type,
- * all "value" columns are cast to the nearest common data type. For instance,
- * types `IntegerType` and `LongType` are cast to `LongType`, while `IntegerType` and `StringType`
- * do not have a common data type and `unpivot` fails with an `AnalysisException`.
- *
- * @param ids Id columns
- * @param values Value columns to unpivot
- * @param variableColumnName Name of the variable column
- * @param valueColumnName Name of the value column
+ * all "value" columns are cast to the nearest common data type. For instance, types
+ * `IntegerType` and `LongType` are cast to `LongType`, while `IntegerType` and `StringType` do
+ * not have a common data type and `unpivot` fails with an `AnalysisException`.
+ *
+ * @param ids
+ * Id columns
+ * @param values
+ * Value columns to unpivot
+ * @param variableColumnName
+ * Name of the variable column
+ * @param valueColumnName
+ * Name of the value column
* @group untypedrel
* @since 3.4.0
*/
@@ -1443,21 +1505,25 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
ids: Array[Column],
values: Array[Column],
variableColumnName: String,
- valueColumnName: String): DS[Row]
+ valueColumnName: String): Dataset[Row]
/**
- * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns set.
- * This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
+ * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns
+ * set. This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
* which cannot be reversed.
*
- * @see `org.apache.spark.sql.Dataset.unpivot(Array, Array, String, String)`
+ * @see
+ * `org.apache.spark.sql.Dataset.unpivot(Array, Array, String, String)`
*
- * This is equivalent to calling `Dataset#unpivot(Array, Array, String, String)`
- * where `values` is set to all non-id columns that exist in the DataFrame.
+ * This is equivalent to calling `Dataset#unpivot(Array, Array, String, String)` where `values`
+ * is set to all non-id columns that exist in the DataFrame.
*
- * @param ids Id columns
- * @param variableColumnName Name of the variable column
- * @param valueColumnName Name of the value column
+ * @param ids
+ * Id columns
+ * @param variableColumnName
+ * Name of the variable column
+ * @param valueColumnName
+ * Name of the value column
*
* @group untypedrel
* @since 3.4.0
@@ -1465,18 +1531,23 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
def unpivot(
ids: Array[Column],
variableColumnName: String,
- valueColumnName: String): DS[Row]
+ valueColumnName: String): Dataset[Row]
/**
- * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns set.
- * This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
+ * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns
+ * set. This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
* which cannot be reversed. This is an alias for `unpivot`.
*
- * @see `org.apache.spark.sql.Dataset.unpivot(Array, Array, String, String)`
- * @param ids Id columns
- * @param values Value columns to unpivot
- * @param variableColumnName Name of the variable column
- * @param valueColumnName Name of the value column
+ * @see
+ * `org.apache.spark.sql.Dataset.unpivot(Array, Array, String, String)`
+ * @param ids
+ * Id columns
+ * @param values
+ * Value columns to unpivot
+ * @param variableColumnName
+ * Name of the variable column
+ * @param valueColumnName
+ * Name of the value column
* @group untypedrel
* @since 3.4.0
*/
@@ -1484,53 +1555,135 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
ids: Array[Column],
values: Array[Column],
variableColumnName: String,
- valueColumnName: String): DS[Row] =
+ valueColumnName: String): Dataset[Row] =
unpivot(ids, values, variableColumnName, valueColumnName)
/**
- * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns set.
- * This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
+ * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns
+ * set. This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
* which cannot be reversed. This is an alias for `unpivot`.
*
- * @see `org.apache.spark.sql.Dataset.unpivot(Array, Array, String, String)`
- *
- * This is equivalent to calling `Dataset#unpivot(Array, Array, String, String)`
- * where `values` is set to all non-id columns that exist in the DataFrame.
- * @param ids Id columns
- * @param variableColumnName Name of the variable column
- * @param valueColumnName Name of the value column
+ * @see
+ * `org.apache.spark.sql.Dataset.unpivot(Array, Array, String, String)`
+ *
+ * This is equivalent to calling `Dataset#unpivot(Array, Array, String, String)` where `values`
+ * is set to all non-id columns that exist in the DataFrame.
+ * @param ids
+ * Id columns
+ * @param variableColumnName
+ * Name of the variable column
+ * @param valueColumnName
+ * Name of the value column
* @group untypedrel
* @since 3.4.0
*/
def melt(
ids: Array[Column],
variableColumnName: String,
- valueColumnName: String): DS[Row] =
+ valueColumnName: String): Dataset[Row] =
unpivot(ids, variableColumnName, valueColumnName)
- /**
- * Define (named) metrics to observe on the Dataset. This method returns an 'observed' Dataset
- * that returns the same result as the input, with the following guarantees:
- * <ul>
- * <li>It will compute the defined aggregates (metrics) on all the data that is flowing through
- * the Dataset at that point.</li>
- * <li>It will report the value of the defined aggregate columns as soon as we reach a completion
- * point. A completion point is either the end of a query (batch mode) or the end of a streaming
- * epoch. The value of the aggregates only reflects the data processed since the previous
- * completion point.</li>
- * </ul>
- *
- * Please note that continuous execution is currently not supported.
- *
- * The metrics columns must either contain a literal (e.g. lit(42)), or should contain one or
- * more aggregate functions (e.g. sum(a) or sum(a + b) + avg(c) - lit(1)). Expressions that
- * contain references to the input Dataset's columns must always be wrapped in an aggregate
- * function.
- *
- * @group typedrel
- * @since 3.0.0
- */
+ /**
+ * Transposes a DataFrame such that the values in the specified index column become the new
+ * columns of the DataFrame.
+ *
+ * Please note:
+ * - All columns except the index column must share a least common data type. Unless they are
+ * the same data type, all columns are cast to the nearest common data type.
+ * - The name of the column into which the original column names are transposed defaults to
+ * "key".
+ * - null values in the index column are excluded from the column names for the transposed
+ * table, which are ordered in ascending order.
+ *
+ * {{{
+ * val df = Seq(("A", 1, 2), ("B", 3, 4)).toDF("id", "val1", "val2")
+ * df.show()
+ * // output:
+ * // +---+----+----+
+ * // | id|val1|val2|
+ * // +---+----+----+
+ * // | A| 1| 2|
+ * // | B| 3| 4|
+ * // +---+----+----+
+ *
+ * df.transpose($"id").show()
+ * // output:
+ * // +----+---+---+
+ * // | key| A| B|
+ * // +----+---+---+
+ * // |val1| 1| 3|
+ * // |val2| 2| 4|
+ * // +----+---+---+
+ * // schema:
+ * // root
+ * // |-- key: string (nullable = false)
+ * // |-- A: integer (nullable = true)
+ * // |-- B: integer (nullable = true)
+ *
+ * df.transpose().show()
+ * // output:
+ * // +----+---+---+
+ * // | key| A| B|
+ * // +----+---+---+
+ * // |val1| 1| 3|
+ * // |val2| 2| 4|
+ * // +----+---+---+
+ * // schema:
+ * // root
+ * // |-- key: string (nullable = false)
+ * // |-- A: integer (nullable = true)
+ * // |-- B: integer (nullable = true)
+ * }}}
+ *
+ * @param indexColumn
+ * The single column that will be treated as the index for the transpose operation. This
+ * column will be used to pivot the data, transforming the DataFrame such that the values of
+ * the indexColumn become the new columns in the transposed DataFrame.
+ *
+ * @group untypedrel
+ * @since 4.0.0
+ */
+ def transpose(indexColumn: Column): Dataset[Row]
+
+ /**
+ * Transposes a DataFrame, switching rows to columns. This function transforms the DataFrame
+ * such that the values in the first column become the new columns of the DataFrame.
+ *
+ * This is equivalent to calling `Dataset#transpose(Column)` where `indexColumn` is set to the
+ * first column.
+ *
+ * Please note:
+ * - All columns except the index column must share a least common data type. Unless they are
+ * the same data type, all columns are cast to the nearest common data type.
+ * - The name of the column into which the original column names are transposed defaults to
+ * "key".
+ * - Non-"key" column names for the transposed table are ordered in ascending order.
+ *
+ * @group untypedrel
+ * @since 4.0.0
+ */
+ def transpose(): Dataset[Row]
+
+ /**
+ * Define (named) metrics to observe on the Dataset. This method returns an 'observed' Dataset
+ * that returns the same result as the input, with the following guarantees: <ul> <li>It will
+ * compute the defined aggregates (metrics) on all the data that is flowing through the Dataset
+ * at that point.</li> <li>It will report the value of the defined aggregate columns as soon as
+ * we reach a completion point. A completion point is either the end of a query (batch mode) or
+ * the end of a streaming epoch. The value of the aggregates only reflects the data processed
+ * since the previous completion point.</li> </ul> Please note that continuous execution is
+ * currently not supported.
+ *
+ * The metrics columns must either contain a literal (e.g. lit(42)), or should contain one or
+ * more aggregate functions (e.g. sum(a) or sum(a + b) + avg(c) - lit(1)). Expressions that
+ * contain references to the input Dataset's columns must always be wrapped in an aggregate
+ * function.
+ *
+ * @group typedrel
+ * @since 3.0.0
+ */
@scala.annotation.varargs
- def observe(name: String, expr: Column, exprs: Column*): DS[T]
+ def observe(name: String, expr: Column, exprs: Column*): Dataset[T]
/**
* Observe (named) metrics through an `org.apache.spark.sql.Observation` instance. This method
@@ -1546,23 +1699,24 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* val metrics = observation.get
* }}}
*
- * @throws IllegalArgumentException If this is a streaming Dataset (this.isStreaming == true)
+ * @throws IllegalArgumentException
+ * If this is a streaming Dataset (this.isStreaming == true)
*
* @group typedrel
* @since 3.3.0
*/
@scala.annotation.varargs
- def observe(observation: Observation, expr: Column, exprs: Column*): DS[T]
+ def observe(observation: Observation, expr: Column, exprs: Column*): Dataset[T]
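
A sketch of how the `Observation` helper referenced above is typically wired up for a batch query (assuming an active `SparkSession` named `spark`; the metric names are illustrative):

{{{
import org.apache.spark.sql.Observation
import org.apache.spark.sql.functions._
import spark.implicits._

val observation = Observation("my_metrics")
val observed = spark.range(100)
  .observe(observation, count(lit(1)).as("rows"), max($"id").as("max_id"))

// Metrics become available once an action completes the query.
observed.collect()
val metrics: Map[String, Any] = observation.get
}}}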
/**
- * Returns a new Dataset by taking the first `n` rows. The difference between this function
- * and `head` is that `head` is an action and returns an array (by triggering query execution)
- * while `limit` returns a new Dataset.
+ * Returns a new Dataset by taking the first `n` rows. The difference between this function and
+ * `head` is that `head` is an action and returns an array (by triggering query execution) while
+ * `limit` returns a new Dataset.
*
* @group typedrel
* @since 2.0.0
*/
- def limit(n: Int): DS[T]
+ def limit(n: Int): Dataset[T]
/**
* Returns a new Dataset by skipping the first `n` rows.
@@ -1570,7 +1724,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @group typedrel
* @since 3.4.0
*/
- def offset(n: Int): DS[T]
+ def offset(n: Int): Dataset[T]
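
Taken together, `offset` and `limit` allow a simple (order-sensitive) pagination pattern; a sketch assuming a stable sort key:

{{{
import spark.implicits._

val events = spark.range(0, 1000).toDF("id")

// Page 3 with a page size of 50: skip 100 rows, then take 50.
val page = events.orderBy($"id").offset(100).limit(50)
page.count()   // 50
}}}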
/**
* Returns a new Dataset containing union of rows in this Dataset and another Dataset.
@@ -1594,19 +1748,19 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* // +----+----+----+
* }}}
*
- * Notice that the column positions in the schema aren't necessarily matched with the
- * fields in the strongly typed objects in a Dataset. This function resolves columns
- * by their positions in the schema, not the fields in the strongly typed objects. Use
- * [[unionByName]] to resolve columns by field name in the typed objects.
+ * Notice that the column positions in the schema aren't necessarily matched with the fields in
+ * the strongly typed objects in a Dataset. This function resolves columns by their positions in
+ * the schema, not the fields in the strongly typed objects. Use [[unionByName]] to resolve
+ * columns by field name in the typed objects.
*
* @group typedrel
* @since 2.0.0
*/
- def union(other: DS[T]): DS[T]
+ def union(other: DS[T]): Dataset[T]
/**
- * Returns a new Dataset containing union of rows in this Dataset and another Dataset.
- * This is an alias for `union`.
+ * Returns a new Dataset containing union of rows in this Dataset and another Dataset. This is
+ * an alias for `union`.
*
* This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union (that does
* deduplication of elements), use this function followed by a [[distinct]].
@@ -1616,7 +1770,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @group typedrel
* @since 2.0.0
*/
- def unionAll(other: DS[T]): DS[T] = union(other)
+ def unionAll(other: DS[T]): Dataset[T] = union(other)
/**
* Returns a new Dataset containing union of rows in this Dataset and another Dataset.
@@ -1624,8 +1778,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* This is different from both `UNION ALL` and `UNION DISTINCT` in SQL. To do a SQL-style set
* union (that does deduplication of elements), use this function followed by a [[distinct]].
*
- * The difference between this function and [[union]] is that this function
- * resolves columns by name (not by position):
+ * The difference between this function and [[union]] is that this function resolves columns by
+ * name (not by position):
*
* {{{
* val df1 = Seq((1, 2, 3)).toDF("col0", "col1", "col2")
@@ -1647,18 +1801,17 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @group typedrel
* @since 2.3.0
*/
- def unionByName(other: DS[T]): DS[T] = unionByName(other, allowMissingColumns = false)
+ def unionByName(other: DS[T]): Dataset[T] = unionByName(other, allowMissingColumns = false)
/**
* Returns a new Dataset containing union of rows in this Dataset and another Dataset.
*
- * The difference between this function and [[union]] is that this function
- * resolves columns by name (not by position).
+ * The difference between this function and [[union]] is that this function resolves columns by
+ * name (not by position).
*
- * When the parameter `allowMissingColumns` is `true`, the set of column names
- * in this and other `Dataset` can differ; missing columns will be filled with null.
- * Further, the missing columns of this `Dataset` will be added at the end
- * in the schema of the union result:
+ * When the parameter `allowMissingColumns` is `true`, the set of column names in this and other
+ * `Dataset` can differ; missing columns will be filled with null. Further, the missing columns
+ * of this `Dataset` will be added at the end in the schema of the union result:
*
* {{{
* val df1 = Seq((1, 2, 3)).toDF("col0", "col1", "col2")
@@ -1692,150 +1845,169 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @group typedrel
* @since 3.1.0
*/
- def unionByName(other: DS[T], allowMissingColumns: Boolean): DS[T]
+ def unionByName(other: DS[T], allowMissingColumns: Boolean): Dataset[T]
/**
- * Returns a new Dataset containing rows only in both this Dataset and another Dataset.
- * This is equivalent to `INTERSECT` in SQL.
+ * Returns a new Dataset containing rows only in both this Dataset and another Dataset. This is
+ * equivalent to `INTERSECT` in SQL.
*
- * @note Equality checking is performed directly on the encoded representation of the data
- * and thus is not affected by a custom `equals` function defined on `T`.
+ * @note
+ * Equality checking is performed directly on the encoded representation of the data and thus
+ * is not affected by a custom `equals` function defined on `T`.
* @group typedrel
* @since 1.6.0
*/
- def intersect(other: DS[T]): DS[T]
+ def intersect(other: DS[T]): Dataset[T]
/**
* Returns a new Dataset containing rows only in both this Dataset and another Dataset while
- * preserving the duplicates.
- * This is equivalent to `INTERSECT ALL` in SQL.
+ * preserving the duplicates. This is equivalent to `INTERSECT ALL` in SQL.
*
- * @note Equality checking is performed directly on the encoded representation of the data
- * and thus is not affected by a custom `equals` function defined on `T`. Also as standard
- * in SQL, this function resolves columns by position (not by name).
+ * @note
+ * Equality checking is performed directly on the encoded representation of the data and thus
+ * is not affected by a custom `equals` function defined on `T`. Also as standard in SQL, this
+ * function resolves columns by position (not by name).
* @group typedrel
* @since 2.4.0
*/
- def intersectAll(other: DS[T]): DS[T]
+ def intersectAll(other: DS[T]): Dataset[T]
/**
- * Returns a new Dataset containing rows in this Dataset but not in another Dataset.
- * This is equivalent to `EXCEPT DISTINCT` in SQL.
+ * Returns a new Dataset containing rows in this Dataset but not in another Dataset. This is
+ * equivalent to `EXCEPT DISTINCT` in SQL.
*
- * @note Equality checking is performed directly on the encoded representation of the data
- * and thus is not affected by a custom `equals` function defined on `T`.
+ * @note
+ * Equality checking is performed directly on the encoded representation of the data and thus
+ * is not affected by a custom `equals` function defined on `T`.
* @group typedrel
* @since 2.0.0
*/
- def except(other: DS[T]): DS[T]
+ def except(other: DS[T]): Dataset[T]
/**
* Returns a new Dataset containing rows in this Dataset but not in another Dataset while
- * preserving the duplicates.
- * This is equivalent to `EXCEPT ALL` in SQL.
+ * preserving the duplicates. This is equivalent to `EXCEPT ALL` in SQL.
*
- * @note Equality checking is performed directly on the encoded representation of the data
- * and thus is not affected by a custom `equals` function defined on `T`. Also as standard
- * in SQL, this function resolves columns by position (not by name).
+ * @note
+ * Equality checking is performed directly on the encoded representation of the data and thus
+ * is not affected by a custom `equals` function defined on `T`. Also as standard in SQL, this
+ * function resolves columns by position (not by name).
* @group typedrel
* @since 2.4.0
*/
- def exceptAll(other: DS[T]): DS[T]
+ def exceptAll(other: DS[T]): Dataset[T]
/**
- * Returns a new [[Dataset]] by sampling a fraction of rows (without replacement),
- * using a user-supplied seed.
+ * Returns a new [[Dataset]] by sampling a fraction of rows (without replacement), using a
+ * user-supplied seed.
*
- * @param fraction Fraction of rows to generate, range [0.0, 1.0].
- * @param seed Seed for sampling.
- * @note This is NOT guaranteed to provide exactly the fraction of the count
- * of the given [[Dataset]].
+ * @param fraction
+ * Fraction of rows to generate, range [0.0, 1.0].
+ * @param seed
+ * Seed for sampling.
+ * @note
+ * This is NOT guaranteed to provide exactly the fraction of the count of the given
+ * [[Dataset]].
* @group typedrel
* @since 2.3.0
*/
- def sample(fraction: Double, seed: Long): DS[T] = {
+ def sample(fraction: Double, seed: Long): Dataset[T] = {
sample(withReplacement = false, fraction = fraction, seed = seed)
}
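
A sketch of the seeded `sample` overload; as the note above says, the returned row count is only approximately `fraction * count`:

{{{
val ids = spark.range(0, 10000)

// Roughly 10% of the rows, reproducible across runs because of the fixed seed.
val tenPercent = ids.sample(fraction = 0.1, seed = 42L)
tenPercent.count()   // close to 1000, but not guaranteed to be exact
}}}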
/**
- * Returns a new [[Dataset]] by sampling a fraction of rows (without replacement),
- * using a random seed.
+ * Returns a new [[Dataset]] by sampling a fraction of rows (without replacement), using a
+ * random seed.
*
- * @param fraction Fraction of rows to generate, range [0.0, 1.0].
- * @note This is NOT guaranteed to provide exactly the fraction of the count
- * of the given [[Dataset]].
+ * @param fraction
+ * Fraction of rows to generate, range [0.0, 1.0].
+ * @note
+ * This is NOT guaranteed to provide exactly the fraction of the count of the given
+ * [[Dataset]].
* @group typedrel
* @since 2.3.0
*/
- def sample(fraction: Double): DS[T] = {
+ def sample(fraction: Double): Dataset[T] = {
sample(withReplacement = false, fraction = fraction)
}
/**
* Returns a new [[Dataset]] by sampling a fraction of rows, using a user-supplied seed.
*
- * @param withReplacement Sample with replacement or not.
- * @param fraction Fraction of rows to generate, range [0.0, 1.0].
- * @param seed Seed for sampling.
- * @note This is NOT guaranteed to provide exactly the fraction of the count
- * of the given [[Dataset]].
+ * @param withReplacement
+ * Sample with replacement or not.
+ * @param fraction
+ * Fraction of rows to generate, range [0.0, 1.0].
+ * @param seed
+ * Seed for sampling.
+ * @note
+ * This is NOT guaranteed to provide exactly the fraction of the count of the given
+ * [[Dataset]].
* @group typedrel
* @since 1.6.0
*/
- def sample(withReplacement: Boolean, fraction: Double, seed: Long): DS[T]
+ def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T]
/**
* Returns a new [[Dataset]] by sampling a fraction of rows, using a random seed.
*
- * @param withReplacement Sample with replacement or not.
- * @param fraction Fraction of rows to generate, range [0.0, 1.0].
+ * @param withReplacement
+ * Sample with replacement or not.
+ * @param fraction
+ * Fraction of rows to generate, range [0.0, 1.0].
*
- * @note This is NOT guaranteed to provide exactly the fraction of the total count
- * of the given [[Dataset]].
+ * @note
+ * This is NOT guaranteed to provide exactly the fraction of the total count of the given
+ * [[Dataset]].
*
* @group typedrel
* @since 1.6.0
*/
- def sample(withReplacement: Boolean, fraction: Double): DS[T] = {
+ def sample(withReplacement: Boolean, fraction: Double): Dataset[T] = {
sample(withReplacement, fraction, SparkClassUtils.random.nextLong)
}
/**
* Randomly splits this Dataset with the provided weights.
*
- * @param weights weights for splits, will be normalized if they don't sum to 1.
- * @param seed Seed for sampling.
+ * @param weights
+ * weights for splits, will be normalized if they don't sum to 1.
+ * @param seed
+ * Seed for sampling.
*
* For Java API, use [[randomSplitAsList]].
*
* @group typedrel
* @since 2.0.0
*/
- def randomSplit(weights: Array[Double], seed: Long): Array[_ <: DS[T]]
+ def randomSplit(weights: Array[Double], seed: Long): Array[_ <: Dataset[T]]
/**
* Returns a Java list that contains randomly split Dataset with the provided weights.
*
- * @param weights weights for splits, will be normalized if they don't sum to 1.
- * @param seed Seed for sampling.
+ * @param weights
+ * weights for splits, will be normalized if they don't sum to 1.
+ * @param seed
+ * Seed for sampling.
* @group typedrel
* @since 2.0.0
*/
- def randomSplitAsList(weights: Array[Double], seed: Long): util.List[_ <: DS[T]]
+ def randomSplitAsList(weights: Array[Double], seed: Long): util.List[_ <: Dataset[T]]
/**
* Randomly splits this Dataset with the provided weights.
*
- * @param weights weights for splits, will be normalized if they don't sum to 1.
+ * @param weights
+ * weights for splits, will be normalized if they don't sum to 1.
* @group typedrel
* @since 2.0.0
*/
- def randomSplit(weights: Array[Double]): Array[_ <: DS[T]]
+ def randomSplit(weights: Array[Double]): Array[_ <: Dataset[T]]
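
A common use of `randomSplit` is a train/test split; a minimal sketch (weights are normalized if they do not sum to 1):

{{{
val data = spark.range(0, 10000).toDF("id")

// 80/20 split with a fixed seed for reproducibility.
val Array(train, test) = data.randomSplit(Array(0.8, 0.2), seed = 42L)
}}}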
/**
- * (Scala-specific) Returns a new Dataset where each row has been expanded to zero or more
- * rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of
- * the input row are implicitly joined with each row that is output by the function.
+ * (Scala-specific) Returns a new Dataset where each row has been expanded to zero or more rows
+ * by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of the
+ * input row are implicitly joined with each row that is output by the function.
*
* Given that this is deprecated, as an alternative, you can explode columns either using
* `functions.explode()` or `flatMap()`. The following example uses these alternatives to count
@@ -1843,7 +2015,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
*
* {{{
* case class Book(title: String, words: String)
- * val ds: DS[Book]
+ * val ds: Dataset[Book]
*
* val allWords = ds.select($"title", explode(split($"words", " ")).as("word"))
*
@@ -1860,12 +2032,13 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @since 2.0.0
*/
@deprecated("use flatMap() or select() with functions.explode() instead", "2.0.0")
- def explode[A <: Product : TypeTag](input: Column*)(f: Row => IterableOnce[A]): DS[Row]
+ def explode[A <: Product: TypeTag](input: Column*)(f: Row => IterableOnce[A]): Dataset[Row]
/**
- * (Scala-specific) Returns a new Dataset where a single column has been expanded to zero
- * or more rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. All
- * columns of the input row are implicitly joined with each value that is output by the function.
+ * (Scala-specific) Returns a new Dataset where a single column has been expanded to zero or
+ * more rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. All
+ * columns of the input row are implicitly joined with each value that is output by the
+ * function.
*
* Given that this is deprecated, as an alternative, you can explode columns either using
* `functions.explode()`:
@@ -1884,24 +2057,25 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @since 2.0.0
*/
@deprecated("use flatMap() or select() with functions.explode() instead", "2.0.0")
- def explode[A, B : TypeTag](inputColumn: String, outputColumn: String)(f: A => IterableOnce[B])
- : DS[Row]
+ def explode[A, B: TypeTag](inputColumn: String, outputColumn: String)(
+ f: A => IterableOnce[B]): Dataset[Row]
/**
- * Returns a new Dataset by adding a column or replacing the existing column that has
- * the same name.
+ * Returns a new Dataset by adding a column or replacing the existing column that has the same
+ * name.
*
- * `column`'s expression must only refer to attributes supplied by this Dataset. It is an
- * error to add a column that refers to some other Dataset.
+ * `column`'s expression must only refer to attributes supplied by this Dataset. It is an error
+ * to add a column that refers to some other Dataset.
*
- * @note this method introduces a projection internally. Therefore, calling it multiple times,
- * for instance, via loops in order to add multiple columns can generate big plans which
- * can cause performance issues and even `StackOverflowException`. To avoid this,
- * use `select` with the multiple columns at once.
+ * @note
+ * this method introduces a projection internally. Therefore, calling it multiple times, for
+ * instance, via loops in order to add multiple columns can generate big plans which can cause
+ * performance issues and even `StackOverflowException`. To avoid this, use `select` with the
+ * multiple columns at once.
* @group untypedrel
* @since 2.0.0
*/
- def withColumn(colName: String, col: Column): DS[Row] = withColumns(Seq(colName), Seq(col))
+ def withColumn(colName: String, col: Column): Dataset[Row] = withColumns(Seq(colName), Seq(col))
/**
* (Scala-specific) Returns a new Dataset by adding columns or replacing the existing columns
@@ -1913,7 +2087,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @group untypedrel
* @since 3.3.0
*/
- def withColumns(colsMap: Map[String, Column]): DS[Row] = {
+ def withColumns(colsMap: Map[String, Column]): Dataset[Row] = {
val (colNames, newCols) = colsMap.toSeq.unzip
withColumns(colNames, newCols)
}
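
A sketch of the map-based `withColumns`, which adds several columns in one projection and so avoids the plan growth that the `withColumn` note above warns about (column names are illustrative):

{{{
import org.apache.spark.sql.functions._
import spark.implicits._

val df = Seq((1, "a"), (2, "b")).toDF("id", "name")

// Add/replace several columns in a single projection.
val enriched = df.withColumns(Map(
  "id_squared" -> ($"id" * $"id"),
  "name_upper" -> upper($"name")))
}}}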
@@ -1928,60 +2102,55 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @group untypedrel
* @since 3.3.0
*/
- def withColumns(colsMap: util.Map[String, Column]): DS[Row] = withColumns(
- colsMap.asScala.toMap
- )
+ def withColumns(colsMap: util.Map[String, Column]): Dataset[Row] = withColumns(
+ colsMap.asScala.toMap)
/**
- * Returns a new Dataset by adding columns or replacing the existing columns that has
- * the same names.
+ * Returns a new Dataset by adding columns or replacing the existing columns that has the same
+ * names.
*/
- protected def withColumns(colNames: Seq[String], cols: Seq[Column]): DS[Row]
+ protected def withColumns(colNames: Seq[String], cols: Seq[Column]): Dataset[Row]
/**
- * Returns a new Dataset with a column renamed.
- * This is a no-op if schema doesn't contain existingName.
+ * Returns a new Dataset with a column renamed. This is a no-op if schema doesn't contain
+ * existingName.
*
* @group untypedrel
* @since 2.0.0
*/
- def withColumnRenamed(existingName: String, newName: String): DS[Row] =
+ def withColumnRenamed(existingName: String, newName: String): Dataset[Row] =
withColumnsRenamed(Seq(existingName), Seq(newName))
/**
- * (Scala-specific)
- * Returns a new Dataset with a columns renamed.
- * This is a no-op if schema doesn't contain existingName.
+ * (Scala-specific) Returns a new Dataset with a columns renamed. This is a no-op if schema
+ * doesn't contain existingName.
*
* `colsMap` is a map of existing column name and new column name.
*
- * @throws org.apache.spark.sql.AnalysisException if there are duplicate names in resulting
- * projection
+ * @throws org.apache.spark.sql.AnalysisException
+ * if there are duplicate names in resulting projection
* @group untypedrel
* @since 3.4.0
*/
@throws[AnalysisException]
- def withColumnsRenamed(colsMap: Map[String, String]): DS[Row] = {
+ def withColumnsRenamed(colsMap: Map[String, String]): Dataset[Row] = {
val (colNames, newColNames) = colsMap.toSeq.unzip
withColumnsRenamed(colNames, newColNames)
}
/**
- * (Java-specific)
- * Returns a new Dataset with a columns renamed.
- * This is a no-op if schema doesn't contain existingName.
+ * (Java-specific) Returns a new Dataset with a columns renamed. This is a no-op if schema
+ * doesn't contain existingName.
*
* `colsMap` is a map of existing column name and new column name.
*
* @group untypedrel
* @since 3.4.0
*/
- def withColumnsRenamed(colsMap: util.Map[String, String]): DS[Row] =
+ def withColumnsRenamed(colsMap: util.Map[String, String]): Dataset[Row] =
withColumnsRenamed(colsMap.asScala.toMap)
- protected def withColumnsRenamed(
- colNames: Seq[String],
- newColNames: Seq[String]): DS[Row]
+ protected def withColumnsRenamed(colNames: Seq[String], newColNames: Seq[String]): Dataset[Row]
/**
* Returns a new Dataset by updating an existing column with metadata.
@@ -1989,17 +2158,17 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @group untypedrel
* @since 3.3.0
*/
- def withMetadata(columnName: String, metadata: Metadata): DS[Row]
+ def withMetadata(columnName: String, metadata: Metadata): Dataset[Row]
/**
- * Returns a new Dataset with a column dropped. This is a no-op if schema doesn't contain
- * column name.
+ * Returns a new Dataset with a column dropped. This is a no-op if schema doesn't contain column
+ * name.
*
* This method can only be used to drop top level columns. the colName string is treated
* literally without further interpretation.
*
- * Note: `drop(colName)` has different semantic with `drop(col(colName))`, for example:
- * 1, multi column have the same colName:
+ * Note: `drop(colName)` has different semantic with `drop(col(colName))`, for example: 1, multi
+ * column have the same colName:
* {{{
* val df1 = spark.range(0, 2).withColumn("key1", lit(1))
* val df2 = spark.range(0, 2).withColumn("key2", lit(2))
@@ -2062,111 +2231,108 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @group untypedrel
* @since 2.0.0
*/
- def drop(colName: String): DS[Row] = drop(colName :: Nil : _*)
+ def drop(colName: String): Dataset[Row] = drop(colName :: Nil: _*)
/**
- * Returns a new Dataset with columns dropped.
- * This is a no-op if schema doesn't contain column name(s).
+ * Returns a new Dataset with columns dropped. This is a no-op if schema doesn't contain column
+ * name(s).
*
- * This method can only be used to drop top level columns. the colName string is treated literally
- * without further interpretation.
+ * This method can only be used to drop top level columns. the colName string is treated
+ * literally without further interpretation.
*
* @group untypedrel
* @since 2.0.0
*/
@scala.annotation.varargs
- def drop(colNames: String*): DS[Row]
+ def drop(colNames: String*): Dataset[Row]
/**
* Returns a new Dataset with column dropped.
*
- * This method can only be used to drop top level column.
- * This version of drop accepts a [[org.apache.spark.sql.Column]] rather than a name.
- * This is a no-op if the Dataset doesn't have a column
- * with an equivalent expression.
+ * This method can only be used to drop top level column. This version of drop accepts a
+ * [[org.apache.spark.sql.Column]] rather than a name. This is a no-op if the Dataset doesn't
+ * have a column with an equivalent expression.
*
- * Note: `drop(col(colName))` has different semantic with `drop(colName)`,
- * please refer to `Dataset#drop(colName: String)`.
+ * Note: `drop(col(colName))` has different semantic with `drop(colName)`, please refer to
+ * `Dataset#drop(colName: String)`.
*
* @group untypedrel
* @since 2.0.0
*/
- def drop(col: Column): DS[Row] = drop(col, Nil : _*)
+ def drop(col: Column): Dataset[Row] = drop(col, Nil: _*)
/**
* Returns a new Dataset with columns dropped.
*
- * This method can only be used to drop top level columns.
- * This is a no-op if the Dataset doesn't have a columns
- * with an equivalent expression.
+ * This method can only be used to drop top level columns. This is a no-op if the Dataset
+ * doesn't have a columns with an equivalent expression.
*
* @group untypedrel
* @since 3.4.0
*/
@scala.annotation.varargs
- def drop(col: Column, cols: Column*): DS[Row]
+ def drop(col: Column, cols: Column*): Dataset[Row]
/**
- * Returns a new Dataset that contains only the unique rows from this Dataset.
- * This is an alias for `distinct`.
+ * Returns a new Dataset that contains only the unique rows from this Dataset. This is an alias
+ * for `distinct`.
*
* For a static batch [[Dataset]], it just drops duplicate rows. For a streaming [[Dataset]], it
* will keep all data across triggers as intermediate state to drop duplicates rows. You can use
- * [[withWatermark]] to limit how late the duplicate data can be and system will accordingly limit
- * the state. In addition, too late data older than watermark will be dropped to avoid any
+ * [[withWatermark]] to limit how late the duplicate data can be and system will accordingly
+ * limit the state. In addition, too late data older than watermark will be dropped to avoid any
* possibility of duplicates.
*
* @group typedrel
* @since 2.0.0
*/
- def dropDuplicates(): DS[T]
+ def dropDuplicates(): Dataset[T]
/**
- * (Scala-specific) Returns a new Dataset with duplicate rows removed, considering only
- * the subset of columns.
+ * (Scala-specific) Returns a new Dataset with duplicate rows removed, considering only the
+ * subset of columns.
*
* For a static batch [[Dataset]], it just drops duplicate rows. For a streaming [[Dataset]], it
* will keep all data across triggers as intermediate state to drop duplicates rows. You can use
- * [[withWatermark]] to limit how late the duplicate data can be and system will accordingly limit
- * the state. In addition, too late data older than watermark will be dropped to avoid any
+ * [[withWatermark]] to limit how late the duplicate data can be and system will accordingly
+ * limit the state. In addition, too late data older than watermark will be dropped to avoid any
* possibility of duplicates.
*
* @group typedrel
* @since 2.0.0
*/
- def dropDuplicates(colNames: Seq[String]): DS[T]
+ def dropDuplicates(colNames: Seq[String]): Dataset[T]
/**
- * Returns a new Dataset with duplicate rows removed, considering only
- * the subset of columns.
+ * Returns a new Dataset with duplicate rows removed, considering only the subset of columns.
*
* For a static batch [[Dataset]], it just drops duplicate rows. For a streaming [[Dataset]], it
* will keep all data across triggers as intermediate state to drop duplicates rows. You can use
- * [[withWatermark]] to limit how late the duplicate data can be and system will accordingly limit
- * the state. In addition, too late data older than watermark will be dropped to avoid any
+ * [[withWatermark]] to limit how late the duplicate data can be and system will accordingly
+ * limit the state. In addition, too late data older than watermark will be dropped to avoid any
* possibility of duplicates.
*
* @group typedrel
* @since 2.0.0
*/
- def dropDuplicates(colNames: Array[String]): DS[T] =
+ def dropDuplicates(colNames: Array[String]): Dataset[T] =
dropDuplicates(colNames.toImmutableArraySeq)
/**
- * Returns a new [[Dataset]] with duplicate rows removed, considering only
- * the subset of columns.
+ * Returns a new [[Dataset]] with duplicate rows removed, considering only the subset of
+ * columns.
*
* For a static batch [[Dataset]], it just drops duplicate rows. For a streaming [[Dataset]], it
* will keep all data across triggers as intermediate state to drop duplicates rows. You can use
- * [[withWatermark]] to limit how late the duplicate data can be and system will accordingly limit
- * the state. In addition, too late data older than watermark will be dropped to avoid any
+ * [[withWatermark]] to limit how late the duplicate data can be and system will accordingly
+ * limit the state. In addition, too late data older than watermark will be dropped to avoid any
* possibility of duplicates.
*
* @group typedrel
* @since 2.0.0
*/
@scala.annotation.varargs
- def dropDuplicates(col1: String, cols: String*): DS[T] = {
+ def dropDuplicates(col1: String, cols: String*): Dataset[T] = {
val colNames: Seq[String] = col1 +: cols
dropDuplicates(colNames)
}
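
A sketch of `dropDuplicates` restricted to a subset of columns (the data and column names are illustrative):

{{{
import spark.implicits._

val logs = Seq((1, "click", "2024-01-01"), (1, "click", "2024-01-02"), (2, "view", "2024-01-01"))
  .toDF("user_id", "event", "day")

// Keep one row per (user_id, event); which duplicate survives is arbitrary.
val deduped = logs.dropDuplicates("user_id", "event")
deduped.count()   // 2
}}}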
@@ -2177,8 +2343,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* This only works with streaming [[Dataset]], and watermark for the input [[Dataset]] must be
* set via [[withWatermark]].
*
- * For a streaming [[Dataset]], this will keep all data across triggers as intermediate state
- * to drop duplicated rows. The state will be kept to guarantee the semantic, "Events are
+ * For a streaming [[Dataset]], this will keep all data across triggers as intermediate state to
+ * drop duplicated rows. The state will be kept to guarantee the semantic, "Events are
* deduplicated as long as the time distance of earliest and latest events are smaller than the
* delay threshold of watermark." Users are encouraged to set the delay threshold of watermark
* longer than max timestamp differences among duplicated events.
@@ -2188,7 +2354,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @group typedrel
* @since 3.5.0
*/
- def dropDuplicatesWithinWatermark(): DS[T]
+ def dropDuplicatesWithinWatermark(): Dataset[T]
/**
* Returns a new Dataset with duplicates rows removed, considering only the subset of columns,
@@ -2197,8 +2363,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* This only works with streaming [[Dataset]], and watermark for the input [[Dataset]] must be
* set via [[withWatermark]].
*
- * For a streaming [[Dataset]], this will keep all data across triggers as intermediate state
- * to drop duplicated rows. The state will be kept to guarantee the semantic, "Events are
+ * For a streaming [[Dataset]], this will keep all data across triggers as intermediate state to
+ * drop duplicated rows. The state will be kept to guarantee the semantic, "Events are
* deduplicated as long as the time distance of earliest and latest events are smaller than the
* delay threshold of watermark." Users are encouraged to set the delay threshold of watermark
* longer than max timestamp differences among duplicated events.
@@ -2208,7 +2374,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @group typedrel
* @since 3.5.0
*/
- def dropDuplicatesWithinWatermark(colNames: Seq[String]): DS[T]
+ def dropDuplicatesWithinWatermark(colNames: Seq[String]): Dataset[T]
/**
* Returns a new Dataset with duplicates rows removed, considering only the subset of columns,
@@ -2217,8 +2383,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* This only works with streaming [[Dataset]], and watermark for the input [[Dataset]] must be
* set via [[withWatermark]].
*
- * For a streaming [[Dataset]], this will keep all data across triggers as intermediate state
- * to drop duplicated rows. The state will be kept to guarantee the semantic, "Events are
+ * For a streaming [[Dataset]], this will keep all data across triggers as intermediate state to
+ * drop duplicated rows. The state will be kept to guarantee the semantic, "Events are
* deduplicated as long as the time distance of earliest and latest events are smaller than the
* delay threshold of watermark." Users are encouraged to set the delay threshold of watermark
* longer than max timestamp differences among duplicated events.
@@ -2228,7 +2394,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @group typedrel
* @since 3.5.0
*/
- def dropDuplicatesWithinWatermark(colNames: Array[String]): DS[T] = {
+ def dropDuplicatesWithinWatermark(colNames: Array[String]): Dataset[T] = {
dropDuplicatesWithinWatermark(colNames.toImmutableArraySeq)
}
@@ -2239,8 +2405,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* This only works with streaming [[Dataset]], and watermark for the input [[Dataset]] must be
* set via [[withWatermark]].
*
- * For a streaming [[Dataset]], this will keep all data across triggers as intermediate state
- * to drop duplicated rows. The state will be kept to guarantee the semantic, "Events are
+ * For a streaming [[Dataset]], this will keep all data across triggers as intermediate state to
+ * drop duplicated rows. The state will be kept to guarantee the semantic, "Events are
* deduplicated as long as the time distance of earliest and latest events are smaller than the
* delay threshold of watermark." Users are encouraged to set the delay threshold of watermark
* longer than max timestamp differences among duplicated events.
@@ -2251,7 +2417,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @since 3.5.0
*/
@scala.annotation.varargs
- def dropDuplicatesWithinWatermark(col1: String, cols: String*): DS[T] = {
+ def dropDuplicatesWithinWatermark(col1: String, cols: String*): Dataset[T] = {
val colNames: Seq[String] = col1 +: cols
dropDuplicatesWithinWatermark(colNames)
}
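
For the streaming-only `dropDuplicatesWithinWatermark`, a sketch of how it is combined with `withWatermark` (the rate source and derived `id` column are hypothetical stand-ins for a real event stream):

{{{
import org.apache.spark.sql.functions._

// A streaming DataFrame with an event-time column (hypothetical rate source).
val events = spark.readStream.format("rate").load()
  .withColumnRenamed("timestamp", "eventTime")

// Deduplicate on an id column, keeping state only within the 10-minute watermark delay.
val deduped = events
  .withColumn("id", col("value") % 100)
  .withWatermark("eventTime", "10 minutes")
  .dropDuplicatesWithinWatermark("id")
}}}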
@@ -2279,28 +2445,22 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
*
* Use [[summary]] for expanded statistics and control over which statistics to compute.
*
- * @param cols Columns to compute statistics on.
+ * @param cols
+ * Columns to compute statistics on.
* @group action
* @since 1.6.0
*/
@scala.annotation.varargs
- def describe(cols: String*): DS[Row]
-
- /**
- * Computes specified statistics for numeric and string columns. Available statistics are:
- *
- *
count
- *
mean
- *
stddev
- *
min
- *
max
- *
arbitrary approximate percentiles specified as a percentage (e.g. 75%)
- *
count_distinct
- *
approx_count_distinct
- *
+ def describe(cols: String*): Dataset[Row]
+
+ /**
+ * Computes specified statistics for numeric and string columns. Available statistics are:
+ *
count
mean
stddev
min
max
arbitrary
+ * approximate percentiles specified as a percentage (e.g. 75%)
count_distinct
+ *
approx_count_distinct
*
- * If no statistics are given, this function computes count, mean, stddev, min,
- * approximate quartiles (percentiles at 25%, 50%, and 75%), and max.
+ * If no statistics are given, this function computes count, mean, stddev, min, approximate
+ * quartiles (percentiles at 25%, 50%, and 75%), and max.
*
* This function is meant for exploratory data analysis, as we make no guarantee about the
* backward compatibility of the schema of the resulting Dataset. If you want to
@@ -2355,18 +2515,20 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
*
* See also [[describe]] for basic statistics.
*
- * @param statistics Statistics from above list to be computed.
+ * @param statistics
+ * Statistics from above list to be computed.
* @group action
* @since 2.3.0
*/
@scala.annotation.varargs
- def summary(statistics: String*): DS[Row]
+ def summary(statistics: String*): Dataset[Row]
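
A sketch of `summary` with an explicit statistics list, including an approximate percentile (data is illustrative):

{{{
import spark.implicits._

val df = Seq(1, 2, 3, 4, 5, 100).toDF("value")

// Only the requested statistics are computed; "75%" is an approximate percentile.
df.summary("count", "mean", "stddev", "75%").show()
}}}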
/**
* Returns the first `n` rows.
*
- * @note this method should only be used if the resulting array is expected to be small, as
- * all the data is loaded into the driver's memory.
+ * @note
+ * this method should only be used if the resulting array is expected to be small, as all the
+ * data is loaded into the driver's memory.
* @group action
* @since 1.6.0
*/
@@ -2391,7 +2553,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
/**
* Concise syntax for chaining custom transformations.
* {{{
- * def featurize(ds: DS[T]): DS[U] = ...
+ * def featurize(ds: Dataset[T]): Dataset[U] = ...
*
* ds
* .transform(featurize)
@@ -2401,66 +2563,64 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @group typedrel
* @since 1.6.0
*/
- def transform[U](t: DS[T] => DS[U]): DS[U] = t(this.asInstanceOf[DS[T]])
+ def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this.asInstanceOf[Dataset[T]])
/**
- * (Scala-specific)
- * Returns a new Dataset that contains the result of applying `func` to each element.
+ * (Scala-specific) Returns a new Dataset that contains the result of applying `func` to each
+ * element.
*
* @group typedrel
* @since 1.6.0
*/
- def map[U: Encoder](func: T => U): DS[U]
+ def map[U: Encoder](func: T => U): Dataset[U]
/**
- * (Java-specific)
- * Returns a new Dataset that contains the result of applying `func` to each element.
+ * (Java-specific) Returns a new Dataset that contains the result of applying `func` to each
+ * element.
*
* @group typedrel
* @since 1.6.0
*/
- def map[U](func: MapFunction[T, U], encoder: Encoder[U]): DS[U]
+ def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U]
/**
- * (Scala-specific)
- * Returns a new Dataset that contains the result of applying `func` to each partition.
+ * (Scala-specific) Returns a new Dataset that contains the result of applying `func` to each
+ * partition.
*
* @group typedrel
* @since 1.6.0
*/
- def mapPartitions[U: Encoder](func: Iterator[T] => Iterator[U]): DS[U]
+ def mapPartitions[U: Encoder](func: Iterator[T] => Iterator[U]): Dataset[U]
/**
- * (Java-specific)
- * Returns a new Dataset that contains the result of applying `f` to each partition.
+ * (Java-specific) Returns a new Dataset that contains the result of applying `f` to each
+ * partition.
*
* @group typedrel
* @since 1.6.0
*/
- def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): DS[U]
+ def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] =
+ mapPartitions(ToScalaUDF(f))(encoder)
/**
- * (Scala-specific)
- * Returns a new Dataset by first applying a function to all elements of this Dataset,
- * and then flattening the results.
+ * (Scala-specific) Returns a new Dataset by first applying a function to all elements of this
+ * Dataset, and then flattening the results.
*
* @group typedrel
* @since 1.6.0
*/
- def flatMap[U: Encoder](func: T => IterableOnce[U]): DS[U] =
- mapPartitions(_.flatMap(func))
+ def flatMap[U: Encoder](func: T => IterableOnce[U]): Dataset[U] =
+ mapPartitions(UDFAdaptors.flatMapToMapPartitions[T, U](func))
/**
- * (Java-specific)
- * Returns a new Dataset by first applying a function to all elements of this Dataset,
- * and then flattening the results.
+ * (Java-specific) Returns a new Dataset by first applying a function to all elements of this
+ * Dataset, and then flattening the results.
*
* @group typedrel
* @since 1.6.0
*/
- def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): DS[U] = {
- val func: T => Iterator[U] = x => f.call(x).asScala
- flatMap(func)(encoder)
+ def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
+ mapPartitions(UDFAdaptors.flatMapToMapPartitions(f))(encoder)
}
/**
@@ -2469,16 +2629,19 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @group action
* @since 1.6.0
*/
- def foreach(f: T => Unit): Unit
+ def foreach(f: T => Unit): Unit = {
+ foreachPartition(UDFAdaptors.foreachToForeachPartition(f))
+ }
/**
- * (Java-specific)
- * Runs `func` on each element of this Dataset.
+ * (Java-specific) Runs `func` on each element of this Dataset.
*
* @group action
* @since 1.6.0
*/
- def foreach(func: ForeachFunction[T]): Unit = foreach(func.call)
+ def foreach(func: ForeachFunction[T]): Unit = {
+ foreachPartition(UDFAdaptors.foreachToForeachPartition(func))
+ }
/**
* Applies a function `f` to each partition of this Dataset.
@@ -2489,21 +2652,20 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
def foreachPartition(f: Iterator[T] => Unit): Unit
/**
- * (Java-specific)
- * Runs `func` on each partition of this Dataset.
+ * (Java-specific) Runs `func` on each partition of this Dataset.
*
* @group action
* @since 1.6.0
*/
def foreachPartition(func: ForeachPartitionFunction[T]): Unit = {
- foreachPartition((it: Iterator[T]) => func.call(it.asJava))
+ foreachPartition(ToScalaUDF(func))
}
/**
* Returns the first `n` rows in the Dataset.
*
- * Running take requires moving data into the application's driver process, and doing so with
- * a very large `n` can crash the driver process with OutOfMemoryError.
+ * Running take requires moving data into the application's driver process, and doing so with a
+ * very large `n` can crash the driver process with OutOfMemoryError.
*
* @group action
* @since 1.6.0
@@ -2513,8 +2675,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
/**
* Returns the last `n` rows in the Dataset.
*
- * Running tail requires moving data into the application's driver process, and doing so with
- * a very large `n` can crash the driver process with OutOfMemoryError.
+ * Running tail requires moving data into the application's driver process, and doing so with a
+ * very large `n` can crash the driver process with OutOfMemoryError.
*
* @group action
* @since 3.0.0
@@ -2524,8 +2686,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
/**
* Returns the first `n` rows in the Dataset as a list.
*
- * Running take requires moving data into the application's driver process, and doing so with
- * a very large `n` can crash the driver process with OutOfMemoryError.
+ * Running take requires moving data into the application's driver process, and doing so with a
+ * very large `n` can crash the driver process with OutOfMemoryError.
*
* @group action
* @since 1.6.0
@@ -2535,8 +2697,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
/**
* Returns an array that contains all rows in this Dataset.
*
- * Running collect requires moving all the data into the application's driver process, and
- * doing so on a very large dataset can crash the driver process with OutOfMemoryError.
+ * Running collect requires moving all the data into the application's driver process, and doing
+ * so on a very large dataset can crash the driver process with OutOfMemoryError.
*
* For Java API, use [[collectAsList]].
*
@@ -2548,8 +2710,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
/**
* Returns a Java list that contains all rows in this Dataset.
*
- * Running collect requires moving all the data into the application's driver process, and
- * doing so on a very large dataset can crash the driver process with OutOfMemoryError.
+ * Running collect requires moving all the data into the application's driver process, and doing
+ * so on a very large dataset can crash the driver process with OutOfMemoryError.
*
* @group action
* @since 1.6.0
@@ -2561,9 +2723,10 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
*
* The iterator will consume as much memory as the largest partition in this Dataset.
*
- * @note this results in multiple Spark jobs, and if the input Dataset is the result
- * of a wide transformation (e.g. join with different partitioners), to avoid
- * recomputing the input Dataset should be cached first.
+ * @note
+ * this results in multiple Spark jobs, and if the input Dataset is the result of a wide
+ * transformation (e.g. join with different partitioners), the input Dataset should be cached
+ * first to avoid recomputing it.
* @group action
* @since 2.0.0
*/
@@ -2583,15 +2746,15 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @group typedrel
* @since 1.6.0
*/
- def repartition(numPartitions: Int): DS[T]
+ def repartition(numPartitions: Int): Dataset[T]
protected def repartitionByExpression(
numPartitions: Option[Int],
- partitionExprs: Seq[Column]): DS[T]
+ partitionExprs: Seq[Column]): Dataset[T]
/**
- * Returns a new Dataset partitioned by the given partitioning expressions into
- * `numPartitions`. The resulting Dataset is hash partitioned.
+ * Returns a new Dataset partitioned by the given partitioning expressions into `numPartitions`.
+ * The resulting Dataset is hash partitioned.
*
* This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL).
*
@@ -2599,14 +2762,14 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @since 2.0.0
*/
@scala.annotation.varargs
- def repartition(numPartitions: Int, partitionExprs: Column*): DS[T] = {
+ def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = {
repartitionByExpression(Some(numPartitions), partitionExprs)
}
/**
* Returns a new Dataset partitioned by the given partitioning expressions, using
- * `spark.sql.shuffle.partitions` as number of partitions.
- * The resulting Dataset is hash partitioned.
+ * `spark.sql.shuffle.partitions` as number of partitions. The resulting Dataset is hash
+ * partitioned.
*
* This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL).
*
@@ -2614,92 +2777,88 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @since 2.0.0
*/
@scala.annotation.varargs
- def repartition(partitionExprs: Column*): DS[T] = {
+ def repartition(partitionExprs: Column*): Dataset[T] = {
repartitionByExpression(None, partitionExprs)
}
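A minimal sketch of the two hash-repartition overloads above, assuming a DataFrame `df` with a `key` column and `spark.implicits._` in scope (illustrative names):
{{{
// Hash partition into an explicit number of partitions; this is the
// "DISTRIBUTE BY key" behaviour with 8 output partitions.
val hashed8 = df.repartition(8, $"key")

// Let spark.sql.shuffle.partitions decide the partition count.
val hashedDefault = df.repartition($"key")
}}}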
protected def repartitionByRange(
numPartitions: Option[Int],
- partitionExprs: Seq[Column]): DS[T]
-
+ partitionExprs: Seq[Column]): Dataset[T]
/**
- * Returns a new Dataset partitioned by the given partitioning expressions into
- * `numPartitions`. The resulting Dataset is range partitioned.
- *
- * At least one partition-by expression must be specified.
- * When no explicit sort order is specified, "ascending nulls first" is assumed.
- * Note, the rows are not sorted in each partition of the resulting Dataset.
+ * Returns a new Dataset partitioned by the given partitioning expressions into `numPartitions`.
+ * The resulting Dataset is range partitioned.
*
+ * At least one partition-by expression must be specified. When no explicit sort order is
+ * specified, "ascending nulls first" is assumed. Note, the rows are not sorted in each
+ * partition of the resulting Dataset.
*
- * Note that due to performance reasons this method uses sampling to estimate the ranges.
- * Hence, the output may not be consistent, since sampling can return different values.
- * The sample size can be controlled by the config
- * `spark.sql.execution.rangeExchange.sampleSizePerPartition`.
+ * Note that due to performance reasons this method uses sampling to estimate the ranges. Hence,
+ * the output may not be consistent, since sampling can return different values. The sample size
+ * can be controlled by the config `spark.sql.execution.rangeExchange.sampleSizePerPartition`.
*
* @group typedrel
* @since 2.3.0
*/
@scala.annotation.varargs
- def repartitionByRange(numPartitions: Int, partitionExprs: Column*): DS[T] = {
+ def repartitionByRange(numPartitions: Int, partitionExprs: Column*): Dataset[T] = {
repartitionByRange(Some(numPartitions), partitionExprs)
}
/**
* Returns a new Dataset partitioned by the given partitioning expressions, using
- * `spark.sql.shuffle.partitions` as number of partitions.
- * The resulting Dataset is range partitioned.
+ * `spark.sql.shuffle.partitions` as number of partitions. The resulting Dataset is range
+ * partitioned.
*
- * At least one partition-by expression must be specified.
- * When no explicit sort order is specified, "ascending nulls first" is assumed.
- * Note, the rows are not sorted in each partition of the resulting Dataset.
+ * At least one partition-by expression must be specified. When no explicit sort order is
+ * specified, "ascending nulls first" is assumed. Note, the rows are not sorted in each
+ * partition of the resulting Dataset.
*
- * Note that due to performance reasons this method uses sampling to estimate the ranges.
- * Hence, the output may not be consistent, since sampling can return different values.
- * The sample size can be controlled by the config
- * `spark.sql.execution.rangeExchange.sampleSizePerPartition`.
+ * Note that due to performance reasons this method uses sampling to estimate the ranges. Hence,
+ * the output may not be consistent, since sampling can return different values. The sample size
+ * can be controlled by the config `spark.sql.execution.rangeExchange.sampleSizePerPartition`.
*
* @group typedrel
* @since 2.3.0
*/
@scala.annotation.varargs
- def repartitionByRange(partitionExprs: Column*): DS[T] = {
+ def repartitionByRange(partitionExprs: Column*): Dataset[T] = {
repartitionByRange(None, partitionExprs)
}
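A minimal sketch of range repartitioning, assuming the same illustrative `df` with a `ts` column; the sampling config mentioned above is set explicitly to show where it plugs in:
{{{
// Rows are assigned to ranges of ts; they are not sorted within each partition.
// Range boundaries are estimated by sampling, tunable via this config.
spark.conf.set("spark.sql.execution.rangeExchange.sampleSizePerPartition", "200")
val ranged = df.repartitionByRange(16, $"ts")
}}}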
/**
* Returns a new Dataset that has exactly `numPartitions` partitions, when the fewer partitions
* are requested. If a larger number of partitions is requested, it will stay at the current
- * number of partitions. Similar to coalesce defined on an `RDD`, this operation results in
- * a narrow dependency, e.g. if you go from 1000 partitions to 100 partitions, there will not
- * be a shuffle, instead each of the 100 new partitions will claim 10 of the current partitions.
+ * number of partitions. Similar to coalesce defined on an `RDD`, this operation results in a
+ * narrow dependency, e.g. if you go from 1000 partitions to 100 partitions, there will not be a
+ * shuffle, instead each of the 100 new partitions will claim 10 of the current partitions.
*
- * However, if you're doing a drastic coalesce, e.g. to numPartitions = 1,
- * this may result in your computation taking place on fewer nodes than
- * you like (e.g. one node in the case of numPartitions = 1). To avoid this,
- * you can call repartition. This will add a shuffle step, but means the
- * current upstream partitions will be executed in parallel (per whatever
- * the current partitioning is).
+ * However, if you're doing a drastic coalesce, e.g. to numPartitions = 1, this may result in
+ * your computation taking place on fewer nodes than you like (e.g. one node in the case of
+ * numPartitions = 1). To avoid this, you can call repartition. This will add a shuffle step,
+ * but means the current upstream partitions will be executed in parallel (per whatever the
+ * current partitioning is).
*
* @group typedrel
* @since 1.6.0
*/
- def coalesce(numPartitions: Int): DS[T]
+ def coalesce(numPartitions: Int): Dataset[T]
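A short sketch contrasting the two paths described above (illustrative `df`):
{{{
// Narrow-dependency shrink: no shuffle, each new partition absorbs several old ones.
val fewer = df.coalesce(100)

// Drastic reduction: prefer repartition so upstream work stays parallel,
// at the cost of an extra shuffle.
val single = df.repartition(1)
}}}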
/**
- * Returns a new Dataset that contains only the unique rows from this Dataset.
- * This is an alias for `dropDuplicates`.
+ * Returns a new Dataset that contains only the unique rows from this Dataset. This is an alias
+ * for `dropDuplicates`.
*
- * Note that for a streaming [[Dataset]], this method returns distinct rows only once
- * regardless of the output mode, which the behavior may not be same with `DISTINCT` in SQL
- * against streaming [[Dataset]].
+ * Note that for a streaming [[Dataset]], this method returns distinct rows only once regardless
+ * of the output mode, so the behavior may not be the same as `DISTINCT` in SQL against a
+ * streaming [[Dataset]].
*
- * @note Equality checking is performed directly on the encoded representation of the data
- * and thus is not affected by a custom `equals` function defined on `T`.
+ * @note
+ * Equality checking is performed directly on the encoded representation of the data and thus
+ * is not affected by a custom `equals` function defined on `T`.
* @group typedrel
* @since 2.0.0
*/
- def distinct(): DS[T] = dropDuplicates()
+ def distinct(): Dataset[T] = dropDuplicates()
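A short sketch of the alias and its column-subset variant (illustrative `ds` and `df`):
{{{
val deduped   = ds.distinct()                     // same as ds.dropDuplicates()
val byColumns = df.dropDuplicates("id", "name")   // deduplicate on a column subset
}}}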
/**
* Persist this Dataset with the default storage level (`MEMORY_AND_DISK`).
@@ -2707,7 +2866,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @group basic
* @since 1.6.0
*/
- def persist(): DS[T]
+ def persist(): Dataset[T]
/**
* Persist this Dataset with the default storage level (`MEMORY_AND_DISK`).
@@ -2715,19 +2874,18 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @group basic
* @since 1.6.0
*/
- def cache(): DS[T]
-
+ def cache(): Dataset[T]
/**
* Persist this Dataset with the given storage level.
*
- * @param newLevel One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`,
- * `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`,
- * `MEMORY_AND_DISK_2`, etc.
+ * @param newLevel
+ * One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`, `MEMORY_AND_DISK_SER`,
+ * `DISK_ONLY`, `MEMORY_ONLY_2`, `MEMORY_AND_DISK_2`, etc.
* @group basic
* @since 1.6.0
*/
- def persist(newLevel: StorageLevel): DS[T]
+ def persist(newLevel: StorageLevel): Dataset[T]
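A minimal caching sketch, assuming an illustrative `df`:
{{{
import org.apache.spark.storage.StorageLevel

val cached = df.persist(StorageLevel.MEMORY_AND_DISK_SER)
cached.count()                      // first action materializes the cache
println(cached.storageLevel)        // inspect the effective storage level
cached.unpersist(blocking = true)   // free the blocks when done
}}}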
/**
* Get the Dataset's current storage level, or StorageLevel.NONE if not persisted.
@@ -2738,23 +2896,24 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
def storageLevel: StorageLevel
/**
- * Mark the Dataset as non-persistent, and remove all blocks for it from memory and disk.
- * This will not un-persist any cached data that is built upon this Dataset.
+ * Mark the Dataset as non-persistent, and remove all blocks for it from memory and disk. This
+ * will not un-persist any cached data that is built upon this Dataset.
*
- * @param blocking Whether to block until all blocks are deleted.
+ * @param blocking
+ * Whether to block until all blocks are deleted.
* @group basic
* @since 1.6.0
*/
- def unpersist(blocking: Boolean): DS[T]
+ def unpersist(blocking: Boolean): Dataset[T]
/**
- * Mark the Dataset as non-persistent, and remove all blocks for it from memory and disk.
- * This will not un-persist any cached data that is built upon this Dataset.
+ * Mark the Dataset as non-persistent, and remove all blocks for it from memory and disk. This
+ * will not un-persist any cached data that is built upon this Dataset.
*
* @group basic
* @since 1.6.0
*/
- def unpersist(): DS[T]
+ def unpersist(): Dataset[T]
/**
* Registers this Dataset as a temporary table using the given name. The lifetime of this
@@ -2769,14 +2928,15 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
}
/**
- * Creates a local temporary view using the given name. The lifetime of this
- * temporary view is tied to the `SparkSession` that was used to create this Dataset.
+ * Creates a local temporary view using the given name. The lifetime of this temporary view is
+ * tied to the `SparkSession` that was used to create this Dataset.
*
* Local temporary view is session-scoped. Its lifetime is the lifetime of the session that
- * created it, i.e. it will be automatically dropped when the session terminates. It's not
- * tied to any databases, i.e. we can't use `db1.view1` to reference a local temporary view.
+ * created it, i.e. it will be automatically dropped when the session terminates. It's not tied
+ * to any databases, i.e. we can't use `db1.view1` to reference a local temporary view.
*
- * @throws org.apache.spark.sql.AnalysisException if the view name is invalid or already exists
+ * @throws org.apache.spark.sql.AnalysisException
+ * if the view name is invalid or already exists
* @group basic
* @since 2.0.0
*/
@@ -2785,10 +2945,9 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
createTempView(viewName, replace = false, global = false)
}
-
/**
- * Creates a local temporary view using the given name. The lifetime of this
- * temporary view is tied to the `SparkSession` that was used to create this Dataset.
+ * Creates a local temporary view using the given name. The lifetime of this temporary view is
+ * tied to the `SparkSession` that was used to create this Dataset.
*
* @group basic
* @since 2.0.0
@@ -2798,15 +2957,16 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
}
/**
- * Creates a global temporary view using the given name. The lifetime of this
- * temporary view is tied to this Spark application.
+ * Creates a global temporary view using the given name. The lifetime of this temporary view is
+ * tied to this Spark application.
*
- * Global temporary view is cross-session. Its lifetime is the lifetime of the Spark application,
- * i.e. it will be automatically dropped when the application terminates. It's tied to a system
- * preserved database `global_temp`, and we must use the qualified name to refer a global temp
- * view, e.g. `SELECT * FROM global_temp.view1`.
+ * Global temporary view is cross-session. Its lifetime is the lifetime of the Spark
+ * application, i.e. it will be automatically dropped when the application terminates. It's tied
+ * to a system preserved database `global_temp`, and we must use the qualified name to refer to
+ * a global temp view, e.g. `SELECT * FROM global_temp.view1`.
*
- * @throws org.apache.spark.sql.AnalysisException if the view name is invalid or already exists
+ * @throws org.apache.spark.sql.AnalysisException
+ * if the view name is invalid or already exists
* @group basic
* @since 2.1.0
*/
@@ -2819,10 +2979,10 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* Creates or replaces a global temporary view using the given name. The lifetime of this
* temporary view is tied to this Spark application.
*
- * Global temporary view is cross-session. Its lifetime is the lifetime of the Spark application,
- * i.e. it will be automatically dropped when the application terminates. It's tied to a system
- * preserved database `global_temp`, and we must use the qualified name to refer a global temp
- * view, e.g. `SELECT * FROM global_temp.view1`.
+ * Global temporary view is cross-session. Its lifetime is the lifetime of the Spark
+ * application, i.e. it will be automatically dropped when the application terminates. It's tied
+ * to a system preserved database `global_temp`, and we must use the qualified name to refer to
+ * a global temp view, e.g. `SELECT * FROM global_temp.view1`.
*
* @group basic
* @since 2.2.0
@@ -2833,17 +2993,63 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
protected def createTempView(viewName: String, replace: Boolean, global: Boolean): Unit
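A minimal sketch of the local versus global view scope described above, assuming an illustrative `df` and `spark`:
{{{
df.createOrReplaceTempView("people")                 // session-scoped
spark.sql("SELECT name FROM people WHERE age > 21")

df.createOrReplaceGlobalTempView("people_global")    // application-scoped
spark.sql("SELECT name FROM global_temp.people_global")
}}}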
+ /**
+ * Merges a set of updates, insertions, and deletions based on a source table into a target
+ * table.
+ *
+ * Scala Examples:
+ * {{{
+ * spark.table("source")
+ * .mergeInto("target", $"source.id" === $"target.id")
+ * .whenMatched($"salary" === 100)
+ * .delete()
+ * .whenNotMatched()
+ * .insertAll()
+ * .whenNotMatchedBySource($"salary" === 100)
+ * .update(Map(
+ * "salary" -> lit(200)
+ * ))
+ * .merge()
+ * }}}
+ *
+ * @group basic
+ * @since 4.0.0
+ */
+ def mergeInto(table: String, condition: Column): MergeIntoWriter[T]
+
+ /**
+ * Create a write configuration builder for v2 sources.
+ *
+ * This builder is used to configure and execute write operations. For example, to append to an
+ * existing table, run:
+ *
+ * {{{
+ * df.writeTo("catalog.db.table").append()
+ * }}}
+ *
+ * This can also be used to create or replace existing tables:
+ *
+ * {{{
+ * df.writeTo("catalog.db.table").partitionedBy($"col").createOrReplace()
+ * }}}
+ *
+ * @group basic
+ * @since 3.0.0
+ */
+ def writeTo(table: String): DataFrameWriterV2[T]
+
/**
* Returns the content of the Dataset as a Dataset of JSON strings.
*
* @since 2.0.0
*/
- def toJSON: DS[String]
+ def toJSON: Dataset[String]
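A one-line sketch for `toJSON` (illustrative `df`):
{{{
val json = df.toJSON          // Dataset[String], one JSON document per row
json.show(truncate = false)
}}}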
/**
* Returns a best-effort snapshot of the files that compose this Dataset. This method simply
- * asks each constituent BaseRelation for its respective files and takes the union of all results.
- * Depending on the source relations, this may not find all input files. Duplicates are removed.
+ * asks each constituent BaseRelation for its respective files and takes the union of all
+ * results. Depending on the source relations, this may not find all input files. Duplicates are
+ * removed.
*
* @group basic
* @since 2.0.0
@@ -2851,14 +3057,16 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
def inputFiles: Array[String]
/**
- * Returns `true` when the logical query plans inside both [[Dataset]]s are equal and
- * therefore return same results.
+ * Returns `true` when the logical query plans inside both [[Dataset]]s are equal and therefore
+ * return the same results.
*
- * @note The equality comparison here is simplified by tolerating the cosmetic differences
- * such as attribute names.
- * @note This API can compare both [[Dataset]]s very fast but can still return `false` on
- * the [[Dataset]] that return the same results, for instance, from different plans. Such
- * false negative semantic can be useful when caching as an example.
+ * @note
+ * The equality comparison here is simplified by tolerating the cosmetic differences such as
+ * attribute names.
+ * @note
+ * This API can compare both [[Dataset]]s very quickly but can still return `false` for
+ * [[Dataset]]s that return the same results, for instance, from different plans. Such a
+ * false-negative semantic can be useful when caching, for example.
* @since 3.1.0
*/
@DeveloperApi
@@ -2867,8 +3075,9 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
/**
* Returns a `hashCode` of the logical query plan against this [[Dataset]].
*
- * @note Unlike the standard `hashCode`, the hash is calculated against the query plan
- * simplified by tolerating the cosmetic differences such as attribute names.
+ * @note
+ * Unlike the standard `hashCode`, the hash is calculated against the query plan simplified by
+ * tolerating the cosmetic differences such as attribute names.
* @since 3.1.0
*/
@DeveloperApi
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/KeyValueGroupedDataset.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/KeyValueGroupedDataset.scala
new file mode 100644
index 0000000000000..81f999430a128
--- /dev/null
+++ b/sql/api/src/main/scala/org/apache/spark/sql/api/KeyValueGroupedDataset.scala
@@ -0,0 +1,1022 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.api
+
+import org.apache.spark.api.java.function.{CoGroupFunction, FlatMapGroupsFunction, FlatMapGroupsWithStateFunction, MapFunction, MapGroupsFunction, MapGroupsWithStateFunction, ReduceFunction}
+import org.apache.spark.sql.{Column, Encoder, TypedColumn}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.PrimitiveLongEncoder
+import org.apache.spark.sql.functions.{count => cnt, lit}
+import org.apache.spark.sql.internal.{ToScalaUDF, UDFAdaptors}
+import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeMode}
+
+/**
+ * A [[Dataset]] that has been logically grouped by a user-specified grouping key. Users should
+ * not
+ * construct a [[KeyValueGroupedDataset]] directly, but should instead call `groupByKey` on an
+ * existing [[Dataset]].
+ *
+ * @since 2.0.0
+ */
+abstract class KeyValueGroupedDataset[K, V] extends Serializable {
+ type KVDS[KL, VL] <: KeyValueGroupedDataset[KL, VL]
+
+ /**
+ * Returns a new [[KeyValueGroupedDataset]] where the type of the key has been mapped to the
+ * specified type. The mapping of key columns to the type follows the same rules as `as` on
+ * [[Dataset]].
+ *
+ * @since 1.6.0
+ */
+ def keyAs[L: Encoder]: KeyValueGroupedDataset[L, V]
+
+ /**
+ * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied to
+ * the data. The grouping key is unchanged by this.
+ *
+ * {{{
+ * // Create values grouped by key from a Dataset[(K, V)]
+ * ds.groupByKey(_._1).mapValues(_._2) // Scala
+ * }}}
+ *
+ * @since 2.1.0
+ */
+ def mapValues[W: Encoder](func: V => W): KeyValueGroupedDataset[K, W]
+
+ /**
+ * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied to
+ * the data. The grouping key is unchanged by this.
+ *
+ * {{{
+ * // Create Integer values grouped by String key from a Dataset<Tuple2<String, Integer>>
+ * Dataset<Tuple2<String, Integer>> ds = ...;
+ * KeyValueGroupedDataset<String, Integer> grouped =
+ * ds.groupByKey(t -> t._1, Encoders.STRING()).mapValues(t -> t._2, Encoders.INT());
+ * }}}
+ *
+ * @since 2.1.0
+ */
+ def mapValues[W](func: MapFunction[V, W], encoder: Encoder[W]): KeyValueGroupedDataset[K, W] = {
+ mapValues(ToScalaUDF(func))(encoder)
+ }
+
+ /**
+ * Returns a [[Dataset]] that contains each unique key. This is equivalent to doing mapping over
+ * the Dataset to extract the keys and then running a distinct operation on those.
+ *
+ * @since 1.6.0
+ */
+ def keys: Dataset[K]
+
+ /**
+ * (Scala-specific) Applies the given function to each group of data. For each unique group, the
+ * function will be passed the group key and an iterator that contains all of the elements in
+ * the group. The function can return an iterator containing elements of an arbitrary type which
+ * will be returned as a new [[Dataset]].
+ *
+ * This function does not support partial aggregation, and as a result requires shuffling all
+ * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+ * key, it is best to use the reduce function or an
+ * `org.apache.spark.sql.expressions#Aggregator`.
+ *
+ * Internally, the implementation will spill to disk if any given group is too large to fit into
+ * memory. However, users must take care to avoid materializing the whole iterator for a group
+ * (for example, by calling `toList`) unless they are sure that this is possible given the
+ * memory constraints of their cluster.
+ *
+ * @since 1.6.0
+ */
+ def flatMapGroups[U: Encoder](f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] = {
+ flatMapSortedGroups(Nil: _*)(f)
+ }
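A minimal sketch of `groupByKey` followed by `flatMapGroups`, assuming a hypothetical `Dataset[Event]` named `events` and `spark.implicits._` in scope:
{{{
case class Event(user: String, msg: String)

val grouped = events.groupByKey(_.user)

// Emit up to three records per group; the group is consumed lazily through
// the iterator, so it never has to be materialized in memory at once.
val sampled = grouped.flatMapGroups { (user, it) =>
  it.take(3).map(e => s"$user: ${e.msg}")
}
}}}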
+
+ /**
+ * (Java-specific) Applies the given function to each group of data. For each unique group, the
+ * function will be passed the group key and an iterator that contains all of the elements in
+ * the group. The function can return an iterator containing elements of an arbitrary type which
+ * will be returned as a new [[Dataset]].
+ *
+ * This function does not support partial aggregation, and as a result requires shuffling all
+ * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+ * key, it is best to use the reduce function or an
+ * `org.apache.spark.sql.expressions#Aggregator`.
+ *
+ * Internally, the implementation will spill to disk if any given group is too large to fit into
+ * memory. However, users must take care to avoid materializing the whole iterator for a group
+ * (for example, by calling `toList`) unless they are sure that this is possible given the
+ * memory constraints of their cluster.
+ *
+ * @since 1.6.0
+ */
+ def flatMapGroups[U](f: FlatMapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = {
+ flatMapGroups(ToScalaUDF(f))(encoder)
+ }
+
+ /**
+ * (Scala-specific) Applies the given function to each group of data. For each unique group, the
+ * function will be passed the group key and a sorted iterator that contains all of the elements
+ * in the group. The function can return an iterator containing elements of an arbitrary type
+ * which will be returned as a new [[Dataset]].
+ *
+ * This function does not support partial aggregation, and as a result requires shuffling all
+ * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+ * key, it is best to use the reduce function or an
+ * `org.apache.spark.sql.expressions#Aggregator`.
+ *
+ * Internally, the implementation will spill to disk if any given group is too large to fit into
+ * memory. However, users must take care to avoid materializing the whole iterator for a group
+ * (for example, by calling `toList`) unless they are sure that this is possible given the
+ * memory constraints of their cluster.
+ *
+ * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except for the iterator to be
+ * sorted according to the given sort expressions. That sorting does not add computational
+ * complexity.
+ *
+ * @see
+ * `org.apache.spark.sql.api.KeyValueGroupedDataset#flatMapGroups`
+ * @since 3.4.0
+ */
+ def flatMapSortedGroups[U: Encoder](sortExprs: Column*)(
+ f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U]
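Continuing the illustrative `grouped`/`Event` sketch, the sorted variant only changes the order in which the iterator delivers each group:
{{{
import org.apache.spark.sql.functions.col

// Each group's iterator arrives sorted by msg; take the first alphabetically.
val firstPerUser = grouped.flatMapSortedGroups(col("msg")) { (user, it) =>
  it.take(1).map(e => s"$user: ${e.msg}")
}
}}}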
+
+ /**
+ * (Java-specific) Applies the given function to each group of data. For each unique group, the
+ * function will be passed the group key and a sorted iterator that contains all of the elements
+ * in the group. The function can return an iterator containing elements of an arbitrary type
+ * which will be returned as a new [[Dataset]].
+ *
+ * This function does not support partial aggregation, and as a result requires shuffling all
+ * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+ * key, it is best to use the reduce function or an
+ * `org.apache.spark.sql.expressions#Aggregator`.
+ *
+ * Internally, the implementation will spill to disk if any given group is too large to fit into
+ * memory. However, users must take care to avoid materializing the whole iterator for a group
+ * (for example, by calling `toList`) unless they are sure that this is possible given the
+ * memory constraints of their cluster.
+ *
+ * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except for the iterator to be
+ * sorted according to the given sort expressions. That sorting does not add computational
+ * complexity.
+ *
+ * @see
+ * `org.apache.spark.sql.api.KeyValueGroupedDataset#flatMapGroups`
+ * @since 3.4.0
+ */
+ def flatMapSortedGroups[U](
+ SortExprs: Array[Column],
+ f: FlatMapGroupsFunction[K, V, U],
+ encoder: Encoder[U]): Dataset[U] = {
+ import org.apache.spark.util.ArrayImplicits._
+ flatMapSortedGroups(SortExprs.toImmutableArraySeq: _*)(ToScalaUDF(f))(encoder)
+ }
+
+ /**
+ * (Scala-specific) Applies the given function to each group of data. For each unique group, the
+ * function will be passed the group key and an iterator that contains all of the elements in
+ * the group. The function can return an element of arbitrary type which will be returned as a
+ * new [[Dataset]].
+ *
+ * This function does not support partial aggregation, and as a result requires shuffling all
+ * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+ * key, it is best to use the reduce function or an
+ * `org.apache.spark.sql.expressions#Aggregator`.
+ *
+ * Internally, the implementation will spill to disk if any given group is too large to fit into
+ * memory. However, users must take care to avoid materializing the whole iterator for a group
+ * (for example, by calling `toList`) unless they are sure that this is possible given the
+ * memory constraints of their cluster.
+ *
+ * @since 1.6.0
+ */
+ def mapGroups[U: Encoder](f: (K, Iterator[V]) => U): Dataset[U] = {
+ flatMapGroups(UDFAdaptors.mapGroupsToFlatMapGroups(f))
+ }
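Continuing the same illustrative sketch, `mapGroups` returns exactly one record per key:
{{{
// One output row per user: the number of events in the group.
val counts = grouped.mapGroups { (user, it) => (user, it.size) }
}}}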
+
+ /**
+ * (Java-specific) Applies the given function to each group of data. For each unique group, the
+ * function will be passed the group key and an iterator that contains all of the elements in
+ * the group. The function can return an element of arbitrary type which will be returned as a
+ * new [[Dataset]].
+ *
+ * This function does not support partial aggregation, and as a result requires shuffling all
+ * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+ * key, it is best to use the reduce function or an
+ * `org.apache.spark.sql.expressions#Aggregator`.
+ *
+ * Internally, the implementation will spill to disk if any given group is too large to fit into
+ * memory. However, users must take care to avoid materializing the whole iterator for a group
+ * (for example, by calling `toList`) unless they are sure that this is possible given the
+ * memory constraints of their cluster.
+ *
+ * @since 1.6.0
+ */
+ def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = {
+ mapGroups(ToScalaUDF(f))(encoder)
+ }
+
+ /**
+ * (Scala-specific) Applies the given function to each group of data, while maintaining a
+ * user-defined per-group state. The result Dataset will represent the objects returned by the
+ * function. For a static batch Dataset, the function will be invoked once per group. For a
+ * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
+ * and updates to each group's state will be saved across invocations. See
+ * [[org.apache.spark.sql.streaming.GroupState]] for more details.
+ *
+ * @tparam S
+ * The type of the user-defined state. Must be encodable to Spark SQL types.
+ * @tparam U
+ * The type of the output objects. Must be encodable to Spark SQL types.
+ * @param func
+ * Function to be called on every group.
+ *
+ * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark
+ * SQL.
+ * @since 2.2.0
+ */
+ def mapGroupsWithState[S: Encoder, U: Encoder](
+ func: (K, Iterator[V], GroupState[S]) => U): Dataset[U]
+
+ /**
+ * (Scala-specific) Applies the given function to each group of data, while maintaining a
+ * user-defined per-group state. The result Dataset will represent the objects returned by the
+ * function. For a static batch Dataset, the function will be invoked once per group. For a
+ * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
+ * and updates to each group's state will be saved across invocations. See
+ * [[org.apache.spark.sql.streaming.GroupState]] for more details.
+ *
+ * @tparam S
+ * The type of the user-defined state. Must be encodable to Spark SQL types.
+ * @tparam U
+ * The type of the output objects. Must be encodable to Spark SQL types.
+ * @param func
+ * Function to be called on every group.
+ * @param timeoutConf
+ * Timeout configuration for groups that do not receive data for a while.
+ *
+ * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark
+ * SQL.
+ * @since 2.2.0
+ */
+ def mapGroupsWithState[S: Encoder, U: Encoder](timeoutConf: GroupStateTimeout)(
+ func: (K, Iterator[V], GroupState[S]) => U): Dataset[U]
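A minimal stateful sketch, reusing the illustrative `grouped` value of type `KeyValueGroupedDataset[String, Event]`; the running count per key is kept in `GroupState` across triggers:
{{{
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}

val runningCounts =
  grouped.mapGroupsWithState[Long, (String, Long)](GroupStateTimeout.NoTimeout) {
    (user: String, events: Iterator[Event], state: GroupState[Long]) =>
      val total = state.getOption.getOrElse(0L) + events.size
      state.update(total)   // saved across invocations for streaming Datasets
      (user, total)
  }
}}}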
+
+ /**
+ * (Scala-specific) Applies the given function to each group of data, while maintaining a
+ * user-defined per-group state. The result Dataset will represent the objects returned by the
+ * function. For a static batch Dataset, the function will be invoked once per group. For a
+ * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
+ * and updates to each group's state will be saved across invocations. See
+ * [[org.apache.spark.sql.streaming.GroupState]] for more details.
+ *
+ * @tparam S
+ * The type of the user-defined state. Must be encodable to Spark SQL types.
+ * @tparam U
+ * The type of the output objects. Must be encodable to Spark SQL types.
+ * @param func
+ * Function to be called on every group.
+ * @param timeoutConf
+ * Timeout Conf, see GroupStateTimeout for more details
+ * @param initialState
+ * The user provided state that will be initialized when the first batch of data is processed
+ * in the streaming query. The user defined function will be called on the state data even if
+ * there are no other values in the group. To convert a Dataset ds of type Dataset[(K, S)] to
+ * a KeyValueGroupedDataset[K, S] do {{{ds.groupByKey(x => x._1).mapValues(_._2)}}}
+ *
+ * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark
+ * SQL.
+ * @since 3.2.0
+ */
+ def mapGroupsWithState[S: Encoder, U: Encoder](
+ timeoutConf: GroupStateTimeout,
+ initialState: KVDS[K, S])(func: (K, Iterator[V], GroupState[S]) => U): Dataset[U]
+
+ /**
+ * (Java-specific) Applies the given function to each group of data, while maintaining a
+ * user-defined per-group state. The result Dataset will represent the objects returned by the
+ * function. For a static batch Dataset, the function will be invoked once per group. For a
+ * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
+ * and updates to each group's state will be saved across invocations. See `GroupState` for more
+ * details.
+ *
+ * @tparam S
+ * The type of the user-defined state. Must be encodable to Spark SQL types.
+ * @tparam U
+ * The type of the output objects. Must be encodable to Spark SQL types.
+ * @param func
+ * Function to be called on every group.
+ * @param stateEncoder
+ * Encoder for the state type.
+ * @param outputEncoder
+ * Encoder for the output type.
+ *
+ * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark
+ * SQL.
+ * @since 2.2.0
+ */
+ def mapGroupsWithState[S, U](
+ func: MapGroupsWithStateFunction[K, V, S, U],
+ stateEncoder: Encoder[S],
+ outputEncoder: Encoder[U]): Dataset[U] = {
+ mapGroupsWithState[S, U](ToScalaUDF(func))(stateEncoder, outputEncoder)
+ }
+
+ /**
+ * (Java-specific) Applies the given function to each group of data, while maintaining a
+ * user-defined per-group state. The result Dataset will represent the objects returned by the
+ * function. For a static batch Dataset, the function will be invoked once per group. For a
+ * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
+ * and updates to each group's state will be saved across invocations. See `GroupState` for more
+ * details.
+ *
+ * @tparam S
+ * The type of the user-defined state. Must be encodable to Spark SQL types.
+ * @tparam U
+ * The type of the output objects. Must be encodable to Spark SQL types.
+ * @param func
+ * Function to be called on every group.
+ * @param stateEncoder
+ * Encoder for the state type.
+ * @param outputEncoder
+ * Encoder for the output type.
+ * @param timeoutConf
+ * Timeout configuration for groups that do not receive data for a while.
+ *
+ * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark
+ * SQL.
+ * @since 2.2.0
+ */
+ def mapGroupsWithState[S, U](
+ func: MapGroupsWithStateFunction[K, V, S, U],
+ stateEncoder: Encoder[S],
+ outputEncoder: Encoder[U],
+ timeoutConf: GroupStateTimeout): Dataset[U] = {
+ mapGroupsWithState[S, U](timeoutConf)(ToScalaUDF(func))(stateEncoder, outputEncoder)
+ }
+
+ /**
+ * (Java-specific) Applies the given function to each group of data, while maintaining a
+ * user-defined per-group state. The result Dataset will represent the objects returned by the
+ * function. For a static batch Dataset, the function will be invoked once per group. For a
+ * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
+ * and updates to each group's state will be saved across invocations. See `GroupState` for more
+ * details.
+ *
+ * @tparam S
+ * The type of the user-defined state. Must be encodable to Spark SQL types.
+ * @tparam U
+ * The type of the output objects. Must be encodable to Spark SQL types.
+ * @param func
+ * Function to be called on every group.
+ * @param stateEncoder
+ * Encoder for the state type.
+ * @param outputEncoder
+ * Encoder for the output type.
+ * @param timeoutConf
+ * Timeout configuration for groups that do not receive data for a while.
+ * @param initialState
+ * The user provided state that will be initialized when the first batch of data is processed
+ * in the streaming query. The user defined function will be called on the state data even if
+ * there are no other values in the group.
+ *
+ * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark
+ * SQL.
+ * @since 3.2.0
+ */
+ def mapGroupsWithState[S, U](
+ func: MapGroupsWithStateFunction[K, V, S, U],
+ stateEncoder: Encoder[S],
+ outputEncoder: Encoder[U],
+ timeoutConf: GroupStateTimeout,
+ initialState: KVDS[K, S]): Dataset[U] = {
+ val f = ToScalaUDF(func)
+ mapGroupsWithState[S, U](timeoutConf, initialState)(f)(stateEncoder, outputEncoder)
+ }
+
+ /**
+ * (Scala-specific) Applies the given function to each group of data, while maintaining a
+ * user-defined per-group state. The result Dataset will represent the objects returned by the
+ * function. For a static batch Dataset, the function will be invoked once per group. For a
+ * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
+ * and updates to each group's state will be saved across invocations. See `GroupState` for more
+ * details.
+ *
+ * @tparam S
+ * The type of the user-defined state. Must be encodable to Spark SQL types.
+ * @tparam U
+ * The type of the output objects. Must be encodable to Spark SQL types.
+ * @param func
+ * Function to be called on every group.
+ * @param outputMode
+ * The output mode of the function.
+ * @param timeoutConf
+ * Timeout configuration for groups that do not receive data for a while.
+ *
+ * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark
+ * SQL.
+ * @since 2.2.0
+ */
+ def flatMapGroupsWithState[S: Encoder, U: Encoder](
+ outputMode: OutputMode,
+ timeoutConf: GroupStateTimeout)(
+ func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U]
+
+ /**
+ * (Scala-specific) Applies the given function to each group of data, while maintaining a
+ * user-defined per-group state. The result Dataset will represent the objects returned by the
+ * function. For a static batch Dataset, the function will be invoked once per group. For a
+ * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
+ * and updates to each group's state will be saved across invocations. See `GroupState` for more
+ * details.
+ *
+ * @tparam S
+ * The type of the user-defined state. Must be encodable to Spark SQL types.
+ * @tparam U
+ * The type of the output objects. Must be encodable to Spark SQL types.
+ * @param func
+ * Function to be called on every group.
+ * @param outputMode
+ * The output mode of the function.
+ * @param timeoutConf
+ * Timeout configuration for groups that do not receive data for a while.
+ * @param initialState
+ * The user provided state that will be initialized when the first batch of data is processed
+ * in the streaming query. The user defined function will be called on the state data even if
+ * there are no other values in the group. To convert a Dataset `ds` of type
+ * `Dataset[(K, S)]` to a `KeyValueGroupedDataset[K, S]`, use
+ * {{{ds.groupByKey(x => x._1).mapValues(_._2)}}} See [[org.apache.spark.sql.Encoder]] for
+ * more details on what types are encodable to Spark SQL.
+ * @since 3.2.0
+ */
+ def flatMapGroupsWithState[S: Encoder, U: Encoder](
+ outputMode: OutputMode,
+ timeoutConf: GroupStateTimeout,
+ initialState: KVDS[K, S])(func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U]
+
+ /**
+ * (Java-specific) Applies the given function to each group of data, while maintaining a
+ * user-defined per-group state. The result Dataset will represent the objects returned by the
+ * function. For a static batch Dataset, the function will be invoked once per group. For a
+ * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
+ * and updates to each group's state will be saved across invocations. See `GroupState` for more
+ * details.
+ *
+ * @tparam S
+ * The type of the user-defined state. Must be encodable to Spark SQL types.
+ * @tparam U
+ * The type of the output objects. Must be encodable to Spark SQL types.
+ * @param func
+ * Function to be called on every group.
+ * @param outputMode
+ * The output mode of the function.
+ * @param stateEncoder
+ * Encoder for the state type.
+ * @param outputEncoder
+ * Encoder for the output type.
+ * @param timeoutConf
+ * Timeout configuration for groups that do not receive data for a while.
+ *
+ * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark
+ * SQL.
+ * @since 2.2.0
+ */
+ def flatMapGroupsWithState[S, U](
+ func: FlatMapGroupsWithStateFunction[K, V, S, U],
+ outputMode: OutputMode,
+ stateEncoder: Encoder[S],
+ outputEncoder: Encoder[U],
+ timeoutConf: GroupStateTimeout): Dataset[U] = {
+ val f = ToScalaUDF(func)
+ flatMapGroupsWithState[S, U](outputMode, timeoutConf)(f)(stateEncoder, outputEncoder)
+ }
+
+ /**
+ * (Java-specific) Applies the given function to each group of data, while maintaining a
+ * user-defined per-group state. The result Dataset will represent the objects returned by the
+ * function. For a static batch Dataset, the function will be invoked once per group. For a
+ * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
+ * and updates to each group's state will be saved across invocations. See `GroupState` for more
+ * details.
+ *
+ * @tparam S
+ * The type of the user-defined state. Must be encodable to Spark SQL types.
+ * @tparam U
+ * The type of the output objects. Must be encodable to Spark SQL types.
+ * @param func
+ * Function to be called on every group.
+ * @param outputMode
+ * The output mode of the function.
+ * @param stateEncoder
+ * Encoder for the state type.
+ * @param outputEncoder
+ * Encoder for the output type.
+ * @param timeoutConf
+ * Timeout configuration for groups that do not receive data for a while.
+ * @param initialState
+ * The user provided state that will be initialized when the first batch of data is processed
+ * in the streaming query. The user defined function will be called on the state data even if
+ * there are no other values in the group. To convert a Dataset `ds` of type
+ * `Dataset[(K, S)]` to a `KeyValueGroupedDataset[K, S]`, use
+ * {{{ds.groupByKey(x => x._1).mapValues(_._2)}}}
+ *
+ * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark
+ * SQL.
+ * @since 3.2.0
+ */
+ def flatMapGroupsWithState[S, U](
+ func: FlatMapGroupsWithStateFunction[K, V, S, U],
+ outputMode: OutputMode,
+ stateEncoder: Encoder[S],
+ outputEncoder: Encoder[U],
+ timeoutConf: GroupStateTimeout,
+ initialState: KVDS[K, S]): Dataset[U] = {
+ flatMapGroupsWithState[S, U](outputMode, timeoutConf, initialState)(ToScalaUDF(func))(
+ stateEncoder,
+ outputEncoder)
+ }
+
+ /**
+ * (Scala-specific) Invokes methods defined in the stateful processor used in arbitrary state
+ * API v2. We allow the user to act on a per-group set of input rows along with keyed state and
+ * the user can choose to output/return 0 or more rows. For a streaming dataframe, we will
+ * repeatedly invoke the interface methods for new rows in each trigger and the user's
+ * state/state variables will be stored persistently across invocations.
+ *
+ * @tparam U
+ * The type of the output objects. Must be encodable to Spark SQL types.
+ * @param statefulProcessor
+ * Instance of statefulProcessor whose functions will be invoked by the operator.
+ * @param timeMode
+ * The time mode semantics of the stateful processor for timers and TTL.
+ * @param outputMode
+ * The output mode of the stateful processor.
+ *
+ * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark
+ * SQL.
+ */
+ private[sql] def transformWithState[U: Encoder](
+ statefulProcessor: StatefulProcessor[K, V, U],
+ timeMode: TimeMode,
+ outputMode: OutputMode): Dataset[U]
+
+ /**
+ * (Scala-specific) Invokes methods defined in the stateful processor used in arbitrary state
+ * API v2. We allow the user to act on a per-group set of input rows along with keyed state and
+ * the user can choose to output/return 0 or more rows. For a streaming dataframe, we will
+ * repeatedly invoke the interface methods for new rows in each trigger and the user's
+ * state/state variables will be stored persistently across invocations.
+ *
+ * Downstream operators would use specified eventTimeColumnName to calculate watermark. Note
+ * that TimeMode is set to EventTime to ensure correct flow of watermark.
+ *
+ * @tparam U
+ * The type of the output objects. Must be encodable to Spark SQL types.
+ * @param statefulProcessor
+ * Instance of statefulProcessor whose functions will be invoked by the operator.
+ * @param eventTimeColumnName
+ * eventTime column in the output dataset. Any operations after transformWithState will use
+ * the new eventTimeColumn. The user needs to ensure that the eventTime for emitted output
+ * adheres to the watermark boundary, otherwise streaming query will fail.
+ * @param outputMode
+ * The output mode of the stateful processor.
+ *
+ * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark
+ * SQL.
+ */
+ private[sql] def transformWithState[U: Encoder](
+ statefulProcessor: StatefulProcessor[K, V, U],
+ eventTimeColumnName: String,
+ outputMode: OutputMode): Dataset[U]
+
+ /**
+ * (Java-specific) Invokes methods defined in the stateful processor used in arbitrary state API
+ * v2. We allow the user to act on a per-group set of input rows along with keyed state and the
+ * user can choose to output/return 0 or more rows. For a streaming dataframe, we will
+ * repeatedly invoke the interface methods for new rows in each trigger and the user's
+ * state/state variables will be stored persistently across invocations.
+ *
+ * @tparam U
+ * The type of the output objects. Must be encodable to Spark SQL types.
+ * @param statefulProcessor
+ * Instance of statefulProcessor whose functions will be invoked by the operator.
+ * @param timeMode
+ * The time mode semantics of the stateful processor for timers and TTL.
+ * @param outputMode
+ * The output mode of the stateful processor.
+ * @param outputEncoder
+ * Encoder for the output type.
+ *
+ * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark
+ * SQL.
+ */
+ private[sql] def transformWithState[U: Encoder](
+ statefulProcessor: StatefulProcessor[K, V, U],
+ timeMode: TimeMode,
+ outputMode: OutputMode,
+ outputEncoder: Encoder[U]): Dataset[U] = {
+ transformWithState(statefulProcessor, timeMode, outputMode)(outputEncoder)
+ }
+
+ /**
+ * (Java-specific) Invokes methods defined in the stateful processor used in arbitrary state API
+ * v2. We allow the user to act on a per-group set of input rows along with keyed state and the
+ * user can choose to output/return 0 or more rows.
+ *
+ * For a streaming dataframe, we will repeatedly invoke the interface methods for new rows in
+ * each trigger and the user's state/state variables will be stored persistently across
+ * invocations.
+ *
+ * Downstream operators would use specified eventTimeColumnName to calculate watermark. Note
+ * that TimeMode is set to EventTime to ensure correct flow of watermark.
+ *
+ * @tparam U
+ * The type of the output objects. Must be encodable to Spark SQL types.
+ * @param statefulProcessor
+ * Instance of statefulProcessor whose functions will be invoked by the operator.
+ * @param eventTimeColumnName
+ * eventTime column in the output dataset. Any operations after transformWithState will use
+ * the new eventTimeColumn. The user needs to ensure that the eventTime for emitted output
+ * adheres to the watermark boundary, otherwise streaming query will fail.
+ * @param outputMode
+ * The output mode of the stateful processor.
+ * @param outputEncoder
+ * Encoder for the output type.
+ *
+ * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark
+ * SQL.
+ */
+ private[sql] def transformWithState[U: Encoder](
+ statefulProcessor: StatefulProcessor[K, V, U],
+ eventTimeColumnName: String,
+ outputMode: OutputMode,
+ outputEncoder: Encoder[U]): Dataset[U] = {
+ transformWithState(statefulProcessor, eventTimeColumnName, outputMode)(outputEncoder)
+ }
+
+ /**
+ * (Scala-specific) Invokes methods defined in the stateful processor used in arbitrary state
+ * API v2. Functions as the function above, but with additional initial state.
+ *
+ * @tparam U
+ * The type of the output objects. Must be encodable to Spark SQL types.
+ * @tparam S
+ * The type of initial state objects. Must be encodable to Spark SQL types.
+ * @param statefulProcessor
+ * Instance of statefulProcessor whose functions will be invoked by the operator.
+ * @param timeMode
+ * The time mode semantics of the stateful processor for timers and TTL.
+ * @param outputMode
+ * The output mode of the stateful processor.
+ * @param initialState
+ * User provided initial state that will be used to initiate state for the query in the first
+ * batch.
+ *
+ * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark
+ * SQL.
+ */
+ private[sql] def transformWithState[U: Encoder, S: Encoder](
+ statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
+ timeMode: TimeMode,
+ outputMode: OutputMode,
+ initialState: KVDS[K, S]): Dataset[U]
+
+ /**
+ * (Scala-specific) Invokes methods defined in the stateful processor used in arbitrary state
+ * API v2. Functions as the function above, but with additional eventTimeColumnName for output.
+ *
+ * @tparam U
+ * The type of the output objects. Must be encodable to Spark SQL types.
+ * @tparam S
+ * The type of initial state objects. Must be encodable to Spark SQL types.
+ *
+ * Downstream operators would use specified eventTimeColumnName to calculate watermark. Note
+ * that TimeMode is set to EventTime to ensure correct flow of watermark.
+ *
+ * @param statefulProcessor
+ * Instance of statefulProcessor whose functions will be invoked by the operator.
+ * @param eventTimeColumnName
+ * eventTime column in the output dataset. Any operations after transformWithState will use
+ * the new eventTimeColumn. The user needs to ensure that the eventTime for emitted output
+ * adheres to the watermark boundary, otherwise streaming query will fail.
+ * @param outputMode
+ * The output mode of the stateful processor.
+ * @param initialState
+ * User provided initial state that will be used to initiate state for the query in the first
+ * batch.
+ *
+ * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark
+ * SQL.
+ */
+ private[sql] def transformWithState[U: Encoder, S: Encoder](
+ statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
+ eventTimeColumnName: String,
+ outputMode: OutputMode,
+ initialState: KVDS[K, S]): Dataset[U]
+
+ /**
+ * (Java-specific) Invokes methods defined in the stateful processor used in arbitrary state API
+ * v2. Functions as the function above, but with additional initialStateEncoder for state
+ * encoding.
+ *
+ * @tparam U
+ * The type of the output objects. Must be encodable to Spark SQL types.
+ * @tparam S
+ * The type of initial state objects. Must be encodable to Spark SQL types.
+ * @param statefulProcessor
+ * Instance of statefulProcessor whose functions will be invoked by the operator.
+ * @param timeMode
+ * The time mode semantics of the stateful processor for timers and TTL.
+ * @param outputMode
+ * The output mode of the stateful processor.
+ * @param initialState
+ * User provided initial state that will be used to initiate state for the query in the first
+ * batch.
+ * @param outputEncoder
+ * Encoder for the output type.
+ * @param initialStateEncoder
+ * Encoder for the initial state type.
+ *
+ * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark
+ * SQL.
+ */
+ private[sql] def transformWithState[U: Encoder, S: Encoder](
+ statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
+ timeMode: TimeMode,
+ outputMode: OutputMode,
+ initialState: KVDS[K, S],
+ outputEncoder: Encoder[U],
+ initialStateEncoder: Encoder[S]): Dataset[U] = {
+ transformWithState(statefulProcessor, timeMode, outputMode, initialState)(
+ outputEncoder,
+ initialStateEncoder)
+ }
+
+ /**
+ * (Java-specific) Invokes methods defined in the stateful processor used in arbitrary state API
+ * v2. Functions as the function above, but with additional eventTimeColumnName for output.
+ *
+ * Downstream operators would use specified eventTimeColumnName to calculate watermark. Note
+ * that TimeMode is set to EventTime to ensure correct flow of watermark.
+ *
+ * @tparam U
+ * The type of the output objects. Must be encodable to Spark SQL types.
+ * @tparam S
+ * The type of initial state objects. Must be encodable to Spark SQL types.
+ * @param statefulProcessor
+ * Instance of statefulProcessor whose functions will be invoked by the operator.
+ * @param outputMode
+ * The output mode of the stateful processor.
+ * @param initialState
+ * User provided initial state that will be used to initiate state for the query in the first
+ * batch.
+ * @param eventTimeColumnName
+ * event column in the output dataset. Any operations after transformWithState will use the
+ * new eventTimeColumn. The user needs to ensure that the eventTime for emitted output adheres
+ * to the watermark boundary, otherwise streaming query will fail.
+ * @param outputEncoder
+ * Encoder for the output type.
+ * @param initialStateEncoder
+ * Encoder for the initial state type.
+ *
+ * See [[org.apache.spark.sql.Encoder]] for more details on what types are encodable to Spark
+ * SQL.
+ */
+ private[sql] def transformWithState[U: Encoder, S: Encoder](
+ statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
+ outputMode: OutputMode,
+ initialState: KVDS[K, S],
+ eventTimeColumnName: String,
+ outputEncoder: Encoder[U],
+ initialStateEncoder: Encoder[S]): Dataset[U] = {
+ transformWithState(statefulProcessor, eventTimeColumnName, outputMode, initialState)(
+ outputEncoder,
+ initialStateEncoder)
+ }
+
+ /**
+ * (Scala-specific) Reduces the elements of each group of data using the specified binary
+ * function. The given function must be commutative and associative or the result may be
+ * non-deterministic.
+ *
+ * @since 1.6.0
+ */
+ def reduceGroups(f: (V, V) => V): Dataset[(K, V)]
+
+ /**
+ * (Java-specific) Reduces the elements of each group of data using the specified binary
+ * function. The given function must be commutative and associative or the result may be
+ * non-deterministic.
+ *
+ * @since 1.6.0
+ */
+ def reduceGroups(f: ReduceFunction[V]): Dataset[(K, V)] = {
+ reduceGroups(ToScalaUDF(f))
+ }
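A short sketch against the same illustrative `grouped` data; the reduce function must be commutative and associative:
{{{
// Keep the longest message per user; result type is Dataset[(String, Event)].
val longest = grouped.reduceGroups { (a, b) =>
  if (a.msg.length >= b.msg.length) a else b
}
}}}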
+
+ /**
+ * Internal helper function for building typed aggregations that return tuples. For simplicity
+ * and code reuse, we do this without the help of the type system and then use helper functions
+ * that cast appropriately for the user facing interface.
+ */
+ protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_]
+
+ /**
+ * Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key and the
+ * result of computing this aggregation over all elements in the group.
+ *
+ * @since 1.6.0
+ */
+ def agg[U1](col1: TypedColumn[V, U1]): Dataset[(K, U1)] =
+ aggUntyped(col1).asInstanceOf[Dataset[(K, U1)]]
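A sketch of a typed aggregation over the same illustrative `grouped` data, building `TypedColumn`s with `Column.as[...]` (requires `spark.implicits._` for the encoders and `$`):
{{{
import org.apache.spark.sql.functions.{count, length, max}

// Per user: number of events and length of the longest message,
// yielding a Dataset[(String, Long, Int)].
val stats = grouped.agg(
  count("*").as[Long],
  max(length($"msg")).as[Int])
}}}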
+
+ /**
+ * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and
+ * the result of computing these aggregations over all elements in the group.
+ *
+ * @since 1.6.0
+ */
+ def agg[U1, U2](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)] =
+ aggUntyped(col1, col2).asInstanceOf[Dataset[(K, U1, U2)]]
+
+ /**
+ * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and
+ * the result of computing these aggregations over all elements in the group.
+ *
+ * @since 1.6.0
+ */
+ def agg[U1, U2, U3](
+ col1: TypedColumn[V, U1],
+ col2: TypedColumn[V, U2],
+ col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)] =
+ aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, U1, U2, U3)]]
+
+ /**
+ * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and
+ * the result of computing these aggregations over all elements in the group.
+ *
+ * @since 1.6.0
+ */
+ def agg[U1, U2, U3, U4](
+ col1: TypedColumn[V, U1],
+ col2: TypedColumn[V, U2],
+ col3: TypedColumn[V, U3],
+ col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)] =
+ aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, U1, U2, U3, U4)]]
+
+ /**
+ * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and
+ * the result of computing these aggregations over all elements in the group.
+ *
+ * @since 3.0.0
+ */
+ def agg[U1, U2, U3, U4, U5](
+ col1: TypedColumn[V, U1],
+ col2: TypedColumn[V, U2],
+ col3: TypedColumn[V, U3],
+ col4: TypedColumn[V, U4],
+ col5: TypedColumn[V, U5]): Dataset[(K, U1, U2, U3, U4, U5)] =
+ aggUntyped(col1, col2, col3, col4, col5).asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5)]]
+
+ /**
+ * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and
+ * the result of computing these aggregations over all elements in the group.
+ *
+ * @since 3.0.0
+ */
+ def agg[U1, U2, U3, U4, U5, U6](
+ col1: TypedColumn[V, U1],
+ col2: TypedColumn[V, U2],
+ col3: TypedColumn[V, U3],
+ col4: TypedColumn[V, U4],
+ col5: TypedColumn[V, U5],
+ col6: TypedColumn[V, U6]): Dataset[(K, U1, U2, U3, U4, U5, U6)] =
+ aggUntyped(col1, col2, col3, col4, col5, col6)
+ .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6)]]
+
+ /**
+ * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and
+ * the result of computing these aggregations over all elements in the group.
+ *
+ * @since 3.0.0
+ */
+ def agg[U1, U2, U3, U4, U5, U6, U7](
+ col1: TypedColumn[V, U1],
+ col2: TypedColumn[V, U2],
+ col3: TypedColumn[V, U3],
+ col4: TypedColumn[V, U4],
+ col5: TypedColumn[V, U5],
+ col6: TypedColumn[V, U6],
+ col7: TypedColumn[V, U7]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7)] =
+ aggUntyped(col1, col2, col3, col4, col5, col6, col7)
+ .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6, U7)]]
+
+ /**
+ * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and
+ * the result of computing these aggregations over all elements in the group.
+ *
+ * @since 3.0.0
+ */
+ def agg[U1, U2, U3, U4, U5, U6, U7, U8](
+ col1: TypedColumn[V, U1],
+ col2: TypedColumn[V, U2],
+ col3: TypedColumn[V, U3],
+ col4: TypedColumn[V, U4],
+ col5: TypedColumn[V, U5],
+ col6: TypedColumn[V, U6],
+ col7: TypedColumn[V, U7],
+ col8: TypedColumn[V, U8]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7, U8)] =
+ aggUntyped(col1, col2, col3, col4, col5, col6, col7, col8)
+ .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6, U7, U8)]]
+
+ /**
+ * Returns a [[Dataset]] that contains a tuple with each key and the number of items present for
+ * that key.
+ *
+ * @since 1.6.0
+ */
+ def count(): Dataset[(K, Long)] = agg(cnt(lit(1)).as(PrimitiveLongEncoder))
+
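A hedged sketch of the typed `agg` variants and `count()` above. It assumes an active `spark` session with `spark.implicits._` in scope; the data and column names are illustrative.

```scala
import spark.implicits._
import org.apache.spark.sql.functions.{avg, sum}

val events = Seq(("a", 1), ("a", 2), ("b", 3)).toDS()
val byKey  = events.groupByKey(_._1)

// Plain Columns become TypedColumns via .as[...], which is what agg expects.
val stats  = byKey.agg(sum($"_2").as[Long], avg($"_2").as[Double]) // Dataset[(String, Long, Double)]
val counts = byKey.count()                                         // Dataset[(String, Long)]

stats.show()
counts.show()
```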
+ /**
+ * (Scala-specific) Applies the given function to each cogrouped data. For each unique group,
+ * the function will be passed the grouping key and 2 iterators containing all elements in the
+ * group from [[Dataset]] `this` and `other`. The function can return an iterator containing
+ * elements of an arbitrary type which will be returned as a new [[Dataset]].
+ *
+ * @since 1.6.0
+ */
+ def cogroup[U, R: Encoder](other: KVDS[K, U])(
+ f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = {
+ cogroupSorted(other)(Nil: _*)(Nil: _*)(f)
+ }
+
+ /**
+ * (Java-specific) Applies the given function to each cogrouped data. For each unique group, the
+ * function will be passed the grouping key and 2 iterators containing all elements in the group
+ * from [[Dataset]] `this` and `other`. The function can return an iterator containing elements
+ * of an arbitrary type which will be returned as a new [[Dataset]].
+ *
+ * @since 1.6.0
+ */
+ def cogroup[U, R](
+ other: KVDS[K, U],
+ f: CoGroupFunction[K, V, U, R],
+ encoder: Encoder[R]): Dataset[R] = {
+ cogroup(other)(ToScalaUDF(f))(encoder)
+ }
+
+ /**
+ * (Scala-specific) Applies the given function to each sorted cogrouped data. For each unique
+ * group, the function will be passed the grouping key and 2 sorted iterators containing all
+ * elements in the group from [[Dataset]] `this` and `other`. The function can return an
+ * iterator containing elements of an arbitrary type which will be returned as a new
+ * [[Dataset]].
+ *
+ * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except that the iterators are
+ * sorted according to the given sort expressions. That sorting does not add computational
+ * complexity.
+ *
+ * @see
+ * `org.apache.spark.sql.api.KeyValueGroupedDataset#cogroup`
+ * @since 3.4.0
+ */
+ def cogroupSorted[U, R: Encoder](other: KVDS[K, U])(thisSortExprs: Column*)(
+ otherSortExprs: Column*)(f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R]
+
+ /**
+ * (Java-specific) Applies the given function to each sorted cogrouped data. For each unique
+ * group, the function will be passed the grouping key and 2 sorted iterators containing all
+ * elements in the group from [[Dataset]] `this` and `other`. The function can return an
+ * iterator containing elements of an arbitrary type which will be returned as a new
+ * [[Dataset]].
+ *
+ * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except that the iterators are
+ * sorted according to the given sort expressions. That sorting does not add computational
+ * complexity.
+ *
+ * @see
+ * `org.apache.spark.sql.api.KeyValueGroupedDataset#cogroup`
+ * @since 3.4.0
+ */
+ def cogroupSorted[U, R](
+ other: KVDS[K, U],
+ thisSortExprs: Array[Column],
+ otherSortExprs: Array[Column],
+ f: CoGroupFunction[K, V, U, R],
+ encoder: Encoder[R]): Dataset[R] = {
+ import org.apache.spark.util.ArrayImplicits._
+ cogroupSorted(other)(thisSortExprs.toImmutableArraySeq: _*)(
+ otherSortExprs.toImmutableArraySeq: _*)(ToScalaUDF(f))(encoder)
+ }
+}
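To illustrate the cogroup/cogroupSorted contract above, a hedged sketch with made-up case classes (`Click`, `Purchase`) and an assumed active `spark` session:

```scala
import spark.implicits._
import org.apache.spark.sql.functions.col

case class Click(user: String, ts: Long)
case class Purchase(user: String, amount: Double)

val clicks    = Seq(Click("u1", 5L), Click("u1", 1L), Click("u2", 7L)).toDS().groupByKey(_.user)
val purchases = Seq(Purchase("u1", 9.99), Purchase("u2", 5.00)).toDS().groupByKey(_.user)

// Each group's iterators arrive sorted by the given expressions before f runs.
val summary = clicks.cogroupSorted(purchases)(col("ts"))(col("amount")) {
  (user, cs, ps) => Iterator((user, cs.size, ps.map(_.amount).sum))
}

summary.show() // one row per user: (user, #clicks, total purchase amount)
```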
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/RelationalGroupedDataset.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/RelationalGroupedDataset.scala
index 30b2992d43a00..118b8f1ecd488 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/api/RelationalGroupedDataset.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/api/RelationalGroupedDataset.scala
@@ -21,7 +21,7 @@ import scala.jdk.CollectionConverters._
import _root_.java.util
import org.apache.spark.annotation.Stable
-import org.apache.spark.sql.{functions, Column, Row}
+import org.apache.spark.sql.{functions, Column, Encoder, Row}
/**
* A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy]],
@@ -30,19 +30,18 @@ import org.apache.spark.sql.{functions, Column, Row}
* The main method is the `agg` function, which has multiple variants. This class also contains
* some first-order statistics such as `mean`, `sum` for convenience.
*
- * @note This class was named `GroupedData` in Spark 1.x.
+ * @note
+ * This class was named `GroupedData` in Spark 1.x.
* @since 2.0.0
*/
@Stable
-abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] {
- type RGD <: RelationalGroupedDataset[DS]
-
- protected def df: DS[Row]
+abstract class RelationalGroupedDataset {
+ protected def df: Dataset[Row]
/**
* Create a aggregation based on the grouping column, the grouping type, and the aggregations.
*/
- protected def toDF(aggCols: Seq[Column]): DS[Row]
+ protected def toDF(aggCols: Seq[Column]): Dataset[Row]
protected def selectNumericColumns(colNames: Seq[String]): Seq[Column]
@@ -61,13 +60,21 @@ abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] {
private def aggregateNumericColumns(
colNames: Seq[String],
- function: Column => Column): DS[Row] = {
+ function: Column => Column): Dataset[Row] = {
toDF(selectNumericColumns(colNames).map(function))
}
/**
- * (Scala-specific) Compute aggregates by specifying the column names and
- * aggregate methods. The resulting `DataFrame` will also contain the grouping columns.
+ * Returns a `KeyValueGroupedDataset` where the data is grouped by the grouping expressions of
+ * the current `RelationalGroupedDataset`.
+ *
+ * @since 3.0.0
+ */
+ def as[K: Encoder, T: Encoder]: KeyValueGroupedDataset[K, T]
+
+ /**
+ * (Scala-specific) Compute aggregates by specifying the column names and aggregate methods. The
+ * resulting `DataFrame` will also contain the grouping columns.
*
* The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`.
* {{{
@@ -80,12 +87,12 @@ abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] {
*
* @since 1.3.0
*/
- def agg(aggExpr: (String, String), aggExprs: (String, String)*): DS[Row] =
+ def agg(aggExpr: (String, String), aggExprs: (String, String)*): Dataset[Row] =
toDF((aggExpr +: aggExprs).map(toAggCol))
/**
- * (Scala-specific) Compute aggregates by specifying a map from column name to
- * aggregate methods. The resulting `DataFrame` will also contain the grouping columns.
+ * (Scala-specific) Compute aggregates by specifying a map from column name to aggregate
+ * methods. The resulting `DataFrame` will also contain the grouping columns.
*
* The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`.
* {{{
@@ -98,11 +105,11 @@ abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] {
*
* @since 1.3.0
*/
- def agg(exprs: Map[String, String]): DS[Row] = toDF(exprs.map(toAggCol).toSeq)
+ def agg(exprs: Map[String, String]): Dataset[Row] = toDF(exprs.map(toAggCol).toSeq)
/**
- * (Java-specific) Compute aggregates by specifying a map from column name to
- * aggregate methods. The resulting `DataFrame` will also contain the grouping columns.
+ * (Java-specific) Compute aggregates by specifying a map from column name to aggregate methods.
+ * The resulting `DataFrame` will also contain the grouping columns.
*
* The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`.
* {{{
@@ -113,7 +120,7 @@ abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] {
*
* @since 1.3.0
*/
- def agg(exprs: util.Map[String, String]): DS[Row] = {
+ def agg(exprs: util.Map[String, String]): Dataset[Row] = {
agg(exprs.asScala.toMap)
}
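For the newly added `as[K, T]` bridge above, a hedged sketch (column names and data are assumptions) showing how a relational grouping can be re-exposed as a `KeyValueGroupedDataset` and then processed with the typed API:

```scala
import spark.implicits._

val df = Seq(("a", 1), ("a", 2), ("b", 5)).toDF("key", "value")

// Switch from the untyped grouping to the typed key/value view.
val kvds   = df.groupBy($"key").as[String, (String, Int)]
val totals = kvds.mapGroups((k, rows) => (k, rows.map(_._2).sum))

totals.show()
```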
@@ -149,91 +156,92 @@ abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] {
* @since 1.3.0
*/
@scala.annotation.varargs
- def agg(expr: Column, exprs: Column*): DS[Row] = toDF(expr +: exprs)
+ def agg(expr: Column, exprs: Column*): Dataset[Row] = toDF(expr +: exprs)
/**
- * Count the number of rows for each group.
- * The resulting `DataFrame` will also contain the grouping columns.
+ * Count the number of rows for each group. The resulting `DataFrame` will also contain the
+ * grouping columns.
*
* @since 1.3.0
*/
- def count(): DS[Row] = toDF(functions.count(functions.lit(1)).as("count") :: Nil)
+ def count(): Dataset[Row] = toDF(functions.count(functions.lit(1)).as("count") :: Nil)
/**
- * Compute the average value for each numeric columns for each group. This is an alias for `avg`.
- * The resulting `DataFrame` will also contain the grouping columns.
- * When specified columns are given, only compute the average values for them.
+ * Compute the average value for each numeric column for each group. This is an alias for
+ * `avg`. The resulting `DataFrame` will also contain the grouping columns. When specified
+ * columns are given, only compute the average values for them.
*
* @since 1.3.0
*/
@scala.annotation.varargs
- def mean(colNames: String*): DS[Row] = aggregateNumericColumns(colNames, functions.avg)
-
+ def mean(colNames: String*): Dataset[Row] = aggregateNumericColumns(colNames, functions.avg)
/**
- * Compute the max value for each numeric columns for each group.
- * The resulting `DataFrame` will also contain the grouping columns.
- * When specified columns are given, only compute the max values for them.
+ * Compute the max value for each numeric column for each group. The resulting `DataFrame` will
+ * also contain the grouping columns. When specified columns are given, only compute the max
+ * values for them.
*
* @since 1.3.0
*/
@scala.annotation.varargs
- def max(colNames: String*): DS[Row] = aggregateNumericColumns(colNames, functions.max)
+ def max(colNames: String*): Dataset[Row] = aggregateNumericColumns(colNames, functions.max)
/**
- * Compute the mean value for each numeric columns for each group.
- * The resulting `DataFrame` will also contain the grouping columns.
- * When specified columns are given, only compute the mean values for them.
+ * Compute the mean value for each numeric column for each group. The resulting `DataFrame`
+ * will also contain the grouping columns. When specified columns are given, only compute the
+ * mean values for them.
*
* @since 1.3.0
*/
@scala.annotation.varargs
- def avg(colNames: String*): DS[Row] = aggregateNumericColumns(colNames, functions.avg)
+ def avg(colNames: String*): Dataset[Row] = aggregateNumericColumns(colNames, functions.avg)
/**
- * Compute the min value for each numeric column for each group.
- * The resulting `DataFrame` will also contain the grouping columns.
- * When specified columns are given, only compute the min values for them.
+ * Compute the min value for each numeric column for each group. The resulting `DataFrame` will
+ * also contain the grouping columns. When specified columns are given, only compute the min
+ * values for them.
*
* @since 1.3.0
*/
@scala.annotation.varargs
- def min(colNames: String*): DS[Row] = aggregateNumericColumns(colNames, functions.min)
+ def min(colNames: String*): Dataset[Row] = aggregateNumericColumns(colNames, functions.min)
/**
- * Compute the sum for each numeric columns for each group.
- * The resulting `DataFrame` will also contain the grouping columns.
- * When specified columns are given, only compute the sum for them.
+ * Compute the sum for each numeric column for each group. The resulting `DataFrame` will also
+ * contain the grouping columns. When specified columns are given, only compute the sum for
+ * them.
*
* @since 1.3.0
*/
@scala.annotation.varargs
- def sum(colNames: String*): DS[Row] = aggregateNumericColumns(colNames, functions.sum)
+ def sum(colNames: String*): Dataset[Row] = aggregateNumericColumns(colNames, functions.sum)
/**
* Pivots a column of the current `DataFrame` and performs the specified aggregation.
*
- * Spark will eagerly compute the distinct values in `pivotColumn` so it can determine
- * the resulting schema of the transformation. To avoid any eager computations, provide an
- * explicit list of values via `pivot(pivotColumn: String, values: Seq[Any])`.
+ * Spark will eagerly compute the distinct values in `pivotColumn` so it can determine the
+ * resulting schema of the transformation. To avoid any eager computations, provide an explicit
+ * list of values via `pivot(pivotColumn: String, values: Seq[Any])`.
*
* {{{
* // Compute the sum of earnings for each year by course with each course as a separate column
* df.groupBy("year").pivot("course").sum("earnings")
* }}}
*
- * @see `org.apache.spark.sql.Dataset.unpivot` for the reverse operation,
- * except for the aggregation.
- * @param pivotColumn Name of the column to pivot.
+ * @see
+ * `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the
+ * aggregation.
+ * @param pivotColumn
+ * Name of the column to pivot.
* @since 1.6.0
*/
- def pivot(pivotColumn: String): RGD = pivot(df.col(pivotColumn))
+ def pivot(pivotColumn: String): RelationalGroupedDataset = pivot(df.col(pivotColumn))
/**
- * Pivots a column of the current `DataFrame` and performs the specified aggregation.
- * There are two versions of pivot function: one that requires the caller to specify the list
- * of distinct values to pivot on, and one that does not. The latter is more concise but less
- * efficient, because Spark needs to first compute the list of distinct values internally.
+ * Pivots a column of the current `DataFrame` and performs the specified aggregation. There are
+ * two versions of pivot function: one that requires the caller to specify the list of distinct
+ * values to pivot on, and one that does not. The latter is more concise but less efficient,
+ * because Spark needs to first compute the list of distinct values internally.
*
* {{{
* // Compute the sum of earnings for each year by course with each course as a separate column
@@ -252,21 +260,24 @@ abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] {
* .agg(sum($"earnings"))
* }}}
*
- * @see `org.apache.spark.sql.Dataset.unpivot` for the reverse operation,
- * except for the aggregation.
- * @param pivotColumn Name of the column to pivot.
- * @param values List of values that will be translated to columns in the output DataFrame.
+ * @see
+ * `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the
+ * aggregation.
+ * @param pivotColumn
+ * Name of the column to pivot.
+ * @param values
+ * List of values that will be translated to columns in the output DataFrame.
* @since 1.6.0
*/
- def pivot(pivotColumn: String, values: Seq[Any]): RGD =
+ def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset =
pivot(df.col(pivotColumn), values)
/**
* (Java-specific) Pivots a column of the current `DataFrame` and performs the specified
* aggregation.
*
- * There are two versions of pivot function: one that requires the caller to specify the list
- * of distinct values to pivot on, and one that does not. The latter is more concise but less
+ * There are two versions of pivot function: one that requires the caller to specify the list of
+ * distinct values to pivot on, and one that does not. The latter is more concise but less
* efficient, because Spark needs to first compute the list of distinct values internally.
*
* {{{
@@ -277,62 +288,73 @@ abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] {
* df.groupBy("year").pivot("course").sum("earnings");
* }}}
*
- * @see `org.apache.spark.sql.Dataset.unpivot` for the reverse operation,
- * except for the aggregation.
- * @param pivotColumn Name of the column to pivot.
- * @param values List of values that will be translated to columns in the output DataFrame.
+ * @see
+ * `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the
+ * aggregation.
+ * @param pivotColumn
+ * Name of the column to pivot.
+ * @param values
+ * List of values that will be translated to columns in the output DataFrame.
* @since 1.6.0
*/
- def pivot(pivotColumn: String, values: util.List[Any]): RGD =
+ def pivot(pivotColumn: String, values: util.List[Any]): RelationalGroupedDataset =
pivot(df.col(pivotColumn), values)
/**
* (Java-specific) Pivots a column of the current `DataFrame` and performs the specified
- * aggregation. This is an overloaded version of the `pivot` method with `pivotColumn` of
- * the `String` type.
- *
- * @see `org.apache.spark.sql.Dataset.unpivot` for the reverse operation,
- * except for the aggregation.
- * @param pivotColumn the column to pivot.
- * @param values List of values that will be translated to columns in the output DataFrame.
+ * aggregation. This is an overloaded version of the `pivot` method with `pivotColumn` of the
+ * `String` type.
+ *
+ * @see
+ * `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the
+ * aggregation.
+ * @param pivotColumn
+ * the column to pivot.
+ * @param values
+ * List of values that will be translated to columns in the output DataFrame.
* @since 2.4.0
*/
- def pivot(pivotColumn: Column, values: util.List[Any]): RGD =
+ def pivot(pivotColumn: Column, values: util.List[Any]): RelationalGroupedDataset =
pivot(pivotColumn, values.asScala.toSeq)
/**
* Pivots a column of the current `DataFrame` and performs the specified aggregation.
*
- * Spark will eagerly compute the distinct values in `pivotColumn` so it can determine
- * the resulting schema of the transformation. To avoid any eager computations, provide an
- * explicit list of values via `pivot(pivotColumn: Column, values: Seq[Any])`.
+ * Spark will eagerly compute the distinct values in `pivotColumn` so it can determine the
+ * resulting schema of the transformation. To avoid any eager computations, provide an explicit
+ * list of values via `pivot(pivotColumn: Column, values: Seq[Any])`.
*
* {{{
* // Compute the sum of earnings for each year by course with each course as a separate column
* df.groupBy($"year").pivot($"course").sum($"earnings");
* }}}
*
- * @see `org.apache.spark.sql.Dataset.unpivot` for the reverse operation,
- * except for the aggregation.
- * @param pivotColumn he column to pivot.
+ * @see
+ * `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the
+ * aggregation.
+ * @param pivotColumn
+ * the column to pivot.
* @since 2.4.0
*/
- def pivot(pivotColumn: Column): RGD
+ def pivot(pivotColumn: Column): RelationalGroupedDataset
/**
- * Pivots a column of the current `DataFrame` and performs the specified aggregation.
- * This is an overloaded version of the `pivot` method with `pivotColumn` of the `String` type.
+ * Pivots a column of the current `DataFrame` and performs the specified aggregation. This is an
+ * overloaded version of the `pivot` method with `pivotColumn` of the `String` type.
*
* {{{
* // Compute the sum of earnings for each year by course with each course as a separate column
* df.groupBy($"year").pivot($"course", Seq("dotNET", "Java")).sum($"earnings")
* }}}
*
- * @see `org.apache.spark.sql.Dataset.unpivot` for the reverse operation,
- * except for the aggregation.
- * @param pivotColumn the column to pivot.
- * @param values List of values that will be translated to columns in the output DataFrame.
+ * @see
+ * `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, except for the
+ * aggregation.
+ * @param pivotColumn
+ * the column to pivot.
+ * @param values
+ * List of values that will be translated to columns in the output DataFrame.
* @since 2.4.0
*/
- def pivot(pivotColumn: Column, values: Seq[Any]): RGD
+ def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset
}
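A hedged sketch of the pivot overloads above, mirroring the earnings example in the scaladoc; the data is illustrative and an active `spark` session with `spark.implicits._` is assumed.

```scala
import spark.implicits._

val sales = Seq(
  (2012, "dotNET", 10000), (2012, "Java", 20000),
  (2013, "dotNET", 48000), (2013, "Java", 30000)
).toDF("year", "course", "earnings")

// Without a value list Spark first computes the distinct courses; with one it does not.
sales.groupBy("year").pivot("course").sum("earnings").show()
sales.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings").show()
```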
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala
index d156aba934b68..41d16b16ab1c5 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala
@@ -26,14 +26,14 @@ import _root_.java.net.URI
import _root_.java.util
import org.apache.spark.annotation.{DeveloperApi, Experimental}
-import org.apache.spark.sql.{Encoder, Row}
+import org.apache.spark.sql.{Encoder, Row, RuntimeConfig}
import org.apache.spark.sql.types.StructType
/**
* The entry point to programming Spark with the Dataset and DataFrame API.
*
- * In environments that this has been created upfront (e.g. REPL, notebooks), use the builder
- * to get an existing session:
+ * In environments where this has been created up front (e.g. REPL, notebooks), use the builder to
+ * get an existing session:
*
* {{{
* SparkSession.builder().getOrCreate()
@@ -49,7 +49,8 @@ import org.apache.spark.sql.types.StructType
* .getOrCreate()
* }}}
*/
-abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with Closeable {
+abstract class SparkSession extends Serializable with Closeable {
+
/**
* The version of Spark on which this application is running.
*
@@ -57,6 +58,17 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C
*/
def version: String
+ /**
+ * Runtime configuration interface for Spark.
+ *
+ * This is the interface through which the user can get and set all Spark and Hadoop
+ * configurations that are relevant to Spark SQL. When getting the value of a config, this
+ * defaults to the value set in the underlying `SparkContext`, if any.
+ *
+ * @since 2.0.0
+ */
+ def conf: RuntimeConfig
+
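A minimal sketch of the `conf` handle above, assuming an active `spark` session; the config key is just an example.

```scala
// Read and update SQL configuration at runtime.
spark.conf.set("spark.sql.shuffle.partitions", "64")
val partitions = spark.conf.get("spark.sql.shuffle.partitions")
println(s"shuffle partitions: $partitions")
```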
/**
* A collection of methods for registering user-defined functions (UDF).
*
@@ -72,24 +84,26 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C
* DataTypes.StringType);
* }}}
*
- * @note The user-defined functions must be deterministic. Due to optimization,
- * duplicate invocations may be eliminated or the function may even be invoked more times
- * than it is present in the query.
+ * @note
+ * The user-defined functions must be deterministic. Due to optimization, duplicate
+ * invocations may be eliminated or the function may even be invoked more times than it is
+ * present in the query.
* @since 2.0.0
*/
def udf: UDFRegistration
/**
- * Start a new session with isolated SQL configurations, temporary tables, registered
- * functions are isolated, but sharing the underlying `SparkContext` and cached data.
- *
- * @note Other than the `SparkContext`, all shared state is initialized lazily.
- * This method will force the initialization of the shared state to ensure that parent
- * and child sessions are set up with the same shared state. If the underlying catalog
- * implementation is Hive, this will initialize the metastore, which may take some time.
+ * Start a new session in which SQL configurations, temporary tables and registered functions
+ * are isolated, but the underlying `SparkContext` and cached data are shared.
+ *
+ * @note
+ * Other than the `SparkContext`, all shared state is initialized lazily. This method will
+ * force the initialization of the shared state to ensure that parent and child sessions are
+ * set up with the same shared state. If the underlying catalog implementation is Hive, this
+ * will initialize the metastore, which may take some time.
* @since 2.0.0
*/
- def newSession(): SparkSession[DS]
+ def newSession(): SparkSession
/* --------------------------------- *
| Methods for creating DataFrames |
@@ -101,36 +115,35 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C
* @since 2.0.0
*/
@transient
- def emptyDataFrame: DS[Row]
+ def emptyDataFrame: Dataset[Row]
/**
* Creates a `DataFrame` from a local Seq of Product.
*
* @since 2.0.0
*/
- def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DS[Row]
+ def createDataFrame[A <: Product: TypeTag](data: Seq[A]): Dataset[Row]
/**
- * :: DeveloperApi ::
- * Creates a `DataFrame` from a `java.util.List` containing [[org.apache.spark.sql.Row]]s using
- * the given schema.It is important to make sure that the structure of every
- * [[org.apache.spark.sql.Row]] of the provided List matches the provided schema. Otherwise,
- * there will be runtime exception.
+ * :: DeveloperApi :: Creates a `DataFrame` from a `java.util.List` containing
+ * [[org.apache.spark.sql.Row]]s using the given schema. It is important to make sure that the
+ * structure of every [[org.apache.spark.sql.Row]] of the provided List matches the provided
+ * schema. Otherwise, a runtime exception will be thrown.
*
* @since 2.0.0
*/
@DeveloperApi
- def createDataFrame(rows: util.List[Row], schema: StructType): DS[Row]
+ def createDataFrame(rows: util.List[Row], schema: StructType): Dataset[Row]
/**
* Applies a schema to a List of Java Beans.
*
- * WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
- * SELECT * queries will return the columns in an undefined order.
+ * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, SELECT * queries
+ * will return the columns in an undefined order.
*
* @since 1.6.0
*/
- def createDataFrame(data: util.List[_], beanClass: Class[_]): DS[Row]
+ def createDataFrame(data: util.List[_], beanClass: Class[_]): Dataset[Row]
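A minimal sketch of `createDataFrame` with an explicit schema, as described above; every `Row` must match the `StructType` or a runtime error follows. The schema and rows are illustrative, and an active `spark` session is assumed.

```scala
import java.util.Arrays

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

val schema = StructType(Seq(
  StructField("name", StringType, nullable = false),
  StructField("age", IntegerType, nullable = false)))

val rows   = Arrays.asList(Row("Alice", 34), Row("Bob", 29))
val people = spark.createDataFrame(rows, schema)

people.show()
```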
/* ------------------------------- *
| Methods for creating DataSets |
@@ -141,15 +154,15 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C
*
* @since 2.0.0
*/
- def emptyDataset[T: Encoder]: DS[T]
+ def emptyDataset[T: Encoder]: Dataset[T]
/**
* Creates a [[Dataset]] from a local Seq of data of a given type. This method requires an
- * encoder (to convert a JVM object of type `T` to and from the internal Spark SQL representation)
- * that is generally created automatically through implicits from a `SparkSession`, or can be
- * created explicitly by calling static methods on `Encoders`.
+ * encoder (to convert a JVM object of type `T` to and from the internal Spark SQL
+ * representation) that is generally created automatically through implicits from a
+ * `SparkSession`, or can be created explicitly by calling static methods on `Encoders`.
*
- * == Example ==
+ * ==Example==
*
* {{{
*
@@ -170,15 +183,15 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C
*
* @since 2.0.0
*/
- def createDataset[T: Encoder](data: Seq[T]): DS[T]
+ def createDataset[T: Encoder](data: Seq[T]): Dataset[T]
/**
* Creates a [[Dataset]] from a `java.util.List` of a given type. This method requires an
- * encoder (to convert a JVM object of type `T` to and from the internal Spark SQL representation)
- * that is generally created automatically through implicits from a `SparkSession`, or can be
- * created explicitly by calling static methods on `Encoders`.
+ * encoder (to convert a JVM object of type `T` to and from the internal Spark SQL
+ * representation) that is generally created automatically through implicits from a
+ * `SparkSession`, or can be created explicitly by calling static methods on `Encoders`.
*
- * == Java Example ==
+ * ==Java Example==
*
* {{{
* List data = Arrays.asList("hello", "world");
@@ -187,131 +200,134 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C
*
* @since 2.0.0
*/
- def createDataset[T: Encoder](data: util.List[T]): DS[T]
+ def createDataset[T: Encoder](data: util.List[T]): Dataset[T]
/**
- * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements
- * in a range from 0 to `end` (exclusive) with step value 1.
+ * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a
+ * range from 0 to `end` (exclusive) with step value 1.
*
* @since 2.0.0
*/
- def range(end: Long): DS[lang.Long]
+ def range(end: Long): Dataset[lang.Long]
/**
- * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements
- * in a range from `start` to `end` (exclusive) with step value 1.
+ * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a
+ * range from `start` to `end` (exclusive) with step value 1.
*
* @since 2.0.0
*/
- def range(start: Long, end: Long): DS[lang.Long]
+ def range(start: Long, end: Long): Dataset[lang.Long]
/**
- * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements
- * in a range from `start` to `end` (exclusive) with a step value.
+ * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a
+ * range from `start` to `end` (exclusive) with a step value.
*
* @since 2.0.0
*/
- def range(start: Long, end: Long, step: Long): DS[lang.Long]
+ def range(start: Long, end: Long, step: Long): Dataset[lang.Long]
/**
- * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements
- * in a range from `start` to `end` (exclusive) with a step value, with partition number
- * specified.
+ * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a
+ * range from `start` to `end` (exclusive) with a step value, with partition number specified.
*
* @since 2.0.0
*/
- def range(start: Long, end: Long, step: Long, numPartitions: Int): DS[lang.Long]
+ def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset[lang.Long]
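A minimal sketch of the Dataset-creation helpers above, assuming an active `spark` session with `spark.implicits._` in scope; the values are illustrative.

```scala
import spark.implicits._

val words = spark.createDataset(Seq("hello", "world")) // Dataset[String]
val ids   = spark.range(0, 100, 2, numPartitions = 4)  // Dataset[java.lang.Long], 4 partitions

words.show()
println(ids.count()) // 50
```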
/* ------------------------- *
| Catalog-related methods |
* ------------------------- */
+ /**
+ * Interface through which the user may create, drop, alter or query underlying databases,
+ * tables, functions etc.
+ *
+ * @since 2.0.0
+ */
+ def catalog: Catalog
+
/**
* Returns the specified table/view as a `DataFrame`. If it's a table, it must support batch
- * reading and the returned DataFrame is the batch scan query plan of this table. If it's a view,
- * the returned DataFrame is simply the query plan of the view, which can either be a batch or
- * streaming query plan.
- *
- * @param tableName is either a qualified or unqualified name that designates a table or view.
- * If a database is specified, it identifies the table/view from the database.
- * Otherwise, it first attempts to find a temporary view with the given name
- * and then match the table/view from the current database.
- * Note that, the global temporary view database is also valid here.
+ * reading and the returned DataFrame is the batch scan query plan of this table. If it's a
+ * view, the returned DataFrame is simply the query plan of the view, which can either be a
+ * batch or streaming query plan.
+ *
+ * @param tableName
+ * is either a qualified or unqualified name that designates a table or view. If a database is
+ * specified, it identifies the table/view from the database. Otherwise, it first attempts to
+ * find a temporary view with the given name and then match the table/view from the current
+ * database. Note that, the global temporary view database is also valid here.
* @since 2.0.0
*/
- def table(tableName: String): DS[Row]
+ def table(tableName: String): Dataset[Row]
/* ----------------- *
| Everything else |
* ----------------- */
/**
- * Executes a SQL query substituting positional parameters by the given arguments,
- * returning the result as a `DataFrame`.
- * This API eagerly runs DDL/DML commands, but not for SELECT queries.
- *
- * @param sqlText A SQL statement with positional parameters to execute.
- * @param args An array of Java/Scala objects that can be converted to
- * SQL literal expressions. See
- *
- * Supported Data Types for supported value types in Scala/Java.
- * For example, 1, "Steven", LocalDate.of(2023, 4, 2).
- * A value can be also a `Column` of a literal or collection constructor functions
- * such as `map()`, `array()`, `struct()`, in that case it is taken as is.
+ * Executes a SQL query substituting positional parameters by the given arguments, returning the
+ * result as a `DataFrame`. This API eagerly runs DDL/DML commands, but not for SELECT queries.
+ *
+ * @param sqlText
+ * A SQL statement with positional parameters to execute.
+ * @param args
+ * An array of Java/Scala objects that can be converted to SQL literal expressions. See Supported Data
+ * Types for supported value types in Scala/Java. For example, 1, "Steven",
+ * LocalDate.of(2023, 4, 2). A value can be also a `Column` of a literal or collection
+ * constructor functions such as `map()`, `array()`, `struct()`, in that case it is taken as
+ * is.
* @since 3.5.0
*/
@Experimental
- def sql(sqlText: String, args: Array[_]): DS[Row]
-
- /**
- * Executes a SQL query substituting named parameters by the given arguments,
- * returning the result as a `DataFrame`.
- * This API eagerly runs DDL/DML commands, but not for SELECT queries.
- *
- * @param sqlText A SQL statement with named parameters to execute.
- * @param args A map of parameter names to Java/Scala objects that can be converted to
- * SQL literal expressions. See
- *
- * Supported Data Types for supported value types in Scala/Java.
- * For example, map keys: "rank", "name", "birthdate";
- * map values: 1, "Steven", LocalDate.of(2023, 4, 2).
- * Map value can be also a `Column` of a literal or collection constructor
- * functions such as `map()`, `array()`, `struct()`, in that case it is taken
- * as is.
+ def sql(sqlText: String, args: Array[_]): Dataset[Row]
+
+ /**
+ * Executes a SQL query substituting named parameters by the given arguments, returning the
+ * result as a `DataFrame`. This API eagerly runs DDL/DML commands, but not for SELECT queries.
+ *
+ * @param sqlText
+ * A SQL statement with named parameters to execute.
+ * @param args
+ * A map of parameter names to Java/Scala objects that can be converted to SQL literal
+ * expressions. See
+ * Supported Data Types for supported value types in Scala/Java. For example, map keys:
+ * "rank", "name", "birthdate"; map values: 1, "Steven", LocalDate.of(2023, 4, 2). Map value
+ * can be also a `Column` of a literal or collection constructor functions such as `map()`,
+ * `array()`, `struct()`, in that case it is taken as is.
* @since 3.4.0
*/
@Experimental
- def sql(sqlText: String, args: Map[String, Any]): DS[Row]
-
- /**
- * Executes a SQL query substituting named parameters by the given arguments,
- * returning the result as a `DataFrame`.
- * This API eagerly runs DDL/DML commands, but not for SELECT queries.
- *
- * @param sqlText A SQL statement with named parameters to execute.
- * @param args A map of parameter names to Java/Scala objects that can be converted to
- * SQL literal expressions. See
- *
- * Supported Data Types for supported value types in Scala/Java.
- * For example, map keys: "rank", "name", "birthdate";
- * map values: 1, "Steven", LocalDate.of(2023, 4, 2).
- * Map value can be also a `Column` of a literal or collection constructor
- * functions such as `map()`, `array()`, `struct()`, in that case it is taken
- * as is.
+ def sql(sqlText: String, args: Map[String, Any]): Dataset[Row]
+
+ /**
+ * Executes a SQL query substituting named parameters by the given arguments, returning the
+ * result as a `DataFrame`. This API eagerly runs DDL/DML commands, but not for SELECT queries.
+ *
+ * @param sqlText
+ * A SQL statement with named parameters to execute.
+ * @param args
+ * A map of parameter names to Java/Scala objects that can be converted to SQL literal
+ * expressions. See
+ * Supported Data Types for supported value types in Scala/Java. For example, map keys:
+ * "rank", "name", "birthdate"; map values: 1, "Steven", LocalDate.of(2023, 4, 2). Map value
+ * can be also a `Column` of a literal or collection constructor functions such as `map()`,
+ * `array()`, `struct()`, in that case it is taken as is.
* @since 3.4.0
*/
@Experimental
- def sql(sqlText: String, args: util.Map[String, Any]): DS[Row] = {
+ def sql(sqlText: String, args: util.Map[String, Any]): Dataset[Row] = {
sql(sqlText, args.asScala.toMap)
}
/**
- * Executes a SQL query using Spark, returning the result as a `DataFrame`.
- * This API eagerly runs DDL/DML commands, but not for SELECT queries.
+ * Executes a SQL query using Spark, returning the result as a `DataFrame`. This API eagerly
+ * runs DDL/DML commands, but not for SELECT queries.
*
* @since 2.0.0
*/
- def sql(sqlText: String): DS[Row] = sql(sqlText, Map.empty[String, Any])
+ def sql(sqlText: String): Dataset[Row] = sql(sqlText, Map.empty[String, Any])
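A hedged sketch of the parameterized `sql` variants above; the statement and parameter names are illustrative. Named parameters use the `:name` marker, positional ones use `?`.

```scala
import java.time.LocalDate

// Named parameters.
spark.sql(
  "SELECT :name AS name, :birthdate AS birthdate",
  Map("name" -> "Steven", "birthdate" -> LocalDate.of(2023, 4, 2))).show()

// Positional parameters.
spark.sql("SELECT ? AS rank, ? AS name", Array(1, "Steven")).show()
```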
/**
* Add a single artifact to the current session.
@@ -333,6 +349,24 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C
@Experimental
def addArtifact(uri: URI): Unit
+ /**
+ * Add a single in-memory artifact to the session while preserving the directory structure
+ * specified by `target` under the session's working directory of that particular file
+ * extension.
+ *
+ * Supported target file extensions are .jar and .class.
+ *
+ * ==Example==
+ * {{{
+ * addArtifact(bytesBar, "foo/bar.class")
+ * addArtifact(bytesFlat, "flat.class")
+ * // Directory structure of the session's working directory for class files would look like:
+ * // ${WORKING_DIR_FOR_CLASS_FILES}/flat.class
+ * // ${WORKING_DIR_FOR_CLASS_FILES}/foo/bar.class
+ * }}}
+ *
+ * @since 4.0.0
+ */
@Experimental
def addArtifact(bytes: Array[Byte], target: String): Unit
@@ -367,6 +401,110 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C
@scala.annotation.varargs
def addArtifacts(uri: URI*): Unit
+ /**
+ * Add a tag to be assigned to all the operations started by this thread in this session.
+ *
+ * Often, a unit of execution in an application consists of multiple Spark executions.
+ * Application programmers can use this method to group all those jobs together and give a group
+ * tag. The application can use `org.apache.spark.sql.SparkSession.interruptTag` to cancel all
+ * running executions with this tag. For example:
+ * {{{
+ * // In the main thread:
+ * spark.addTag("myjobs")
+ * spark.range(10).map(i => { Thread.sleep(10); i }).collect()
+ *
+ * // In a separate thread:
+ * spark.interruptTag("myjobs")
+ * }}}
+ *
+ * There may be multiple tags present at the same time, so different parts of the application
+ * may use different tags to perform cancellation at different levels of granularity.
+ *
+ * @param tag
+ * The tag to be added. Cannot contain ',' (comma) character or be an empty string.
+ *
+ * @since 4.0.0
+ */
+ def addTag(tag: String): Unit
+
+ /**
+ * Remove a tag previously added to be assigned to all the operations started by this thread in
+ * this session. Noop if such a tag was not added earlier.
+ *
+ * @param tag
+ * The tag to be removed. Cannot contain ',' (comma) character or be an empty string.
+ *
+ * @since 4.0.0
+ */
+ def removeTag(tag: String): Unit
+
+ /**
+ * Get the operation tags that are currently set to be assigned to all the operations started by
+ * this thread in this session.
+ *
+ * @since 4.0.0
+ */
+ def getTags(): Set[String]
+
+ /**
+ * Clear the current thread's operation tags.
+ *
+ * @since 4.0.0
+ */
+ def clearTags(): Unit
+
+ /**
+ * Request to interrupt all currently running operations of this session.
+ *
+ * @note
+ * This method will wait up to 60 seconds for the interruption request to be issued.
+ *
+ * @return
+ * Sequence of operation IDs requested to be interrupted.
+ *
+ * @since 4.0.0
+ */
+ def interruptAll(): Seq[String]
+
+ /**
+ * Request to interrupt all currently running operations of this session with the given job tag.
+ *
+ * @note
+ * This method will wait up to 60 seconds for the interruption request to be issued.
+ *
+ * @return
+ * Sequence of operation IDs requested to be interrupted.
+ *
+ * @since 4.0.0
+ */
+ def interruptTag(tag: String): Seq[String]
+
+ /**
+ * Request to interrupt an operation of this session, given its operation ID.
+ *
+ * @note
+ * This method will wait up to 60 seconds for the interruption request to be issued.
+ *
+ * @return
+ * The operation ID requested to be interrupted, as a single-element sequence, or an empty
+ * sequence if the operation is not started by this session.
+ *
+ * @since 4.0.0
+ */
+ def interruptOperation(operationId: String): Seq[String]
+
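A hedged sketch of the tagging and interruption methods above; the tag name and the tagged job are illustrative, and an active `spark` session is assumed.

```scala
// In the thread doing the work: tag everything it starts in this session.
spark.addTag("nightly-report")
try {
  spark.range(1000000000L).selectExpr("sum(id)").collect()
} finally {
  spark.removeTag("nightly-report")
}

// In another thread (e.g. a supervisor), cancel everything carrying the tag.
val interruptedIds: Seq[String] = spark.interruptTag("nightly-report")
```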
+ /**
+ * Returns a [[DataFrameReader]] that can be used to read non-streaming data in as a
+ * `DataFrame`.
+ * {{{
+ * sparkSession.read.parquet("/path/to/file.parquet")
+ * sparkSession.read.schema(schema).json("/path/to/file.json")
+ * }}}
+ *
+ * @since 2.0.0
+ */
+ def read: DataFrameReader
+
/**
* Executes some code block and prints to stdout the time taken to execute the block. This is
* available in Scala only and is used primarily for interactive testing and debugging.
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/StreamingQuery.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/StreamingQuery.scala
new file mode 100644
index 0000000000000..0aeb3518facd8
--- /dev/null
+++ b/sql/api/src/main/scala/org/apache/spark/sql/api/StreamingQuery.scala
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api
+
+import _root_.java.util.UUID
+import _root_.java.util.concurrent.TimeoutException
+
+import org.apache.spark.annotation.Evolving
+import org.apache.spark.sql.streaming.{StreamingQueryException, StreamingQueryProgress, StreamingQueryStatus}
+
+/**
+ * A handle to a query that is executing continuously in the background as new data arrives. All
+ * these methods are thread-safe.
+ * @since 2.0.0
+ */
+@Evolving
+trait StreamingQuery {
+
+ /**
+ * Returns the user-specified name of the query, or null if not specified. This name can be
+ * specified in the `org.apache.spark.sql.streaming.DataStreamWriter` as
+ * `dataframe.writeStream.queryName("query").start()`. This name, if set, must be unique across
+ * all active queries.
+ *
+ * @since 2.0.0
+ */
+ def name: String
+
+ /**
+ * Returns the unique id of this query that persists across restarts from checkpoint data. That
+ * is, this id is generated when a query is started for the first time, and will be the same
+ * every time it is restarted from checkpoint data. Also see [[runId]].
+ *
+ * @since 2.1.0
+ */
+ def id: UUID
+
+ /**
+ * Returns the unique id of this run of the query. That is, every start/restart of a query will
+ * generate a unique runId. Therefore, every time a query is restarted from checkpoint, it will
+ * have the same [[id]] but different [[runId]]s.
+ */
+ def runId: UUID
+
+ /**
+ * Returns the `SparkSession` associated with `this`.
+ *
+ * @since 2.0.0
+ */
+ def sparkSession: SparkSession
+
+ /**
+ * Returns `true` if this query is actively running.
+ *
+ * @since 2.0.0
+ */
+ def isActive: Boolean
+
+ /**
+ * Returns the [[org.apache.spark.sql.streaming.StreamingQueryException]] if the query was
+ * terminated by an exception.
+ *
+ * @since 2.0.0
+ */
+ def exception: Option[StreamingQueryException]
+
+ /**
+ * Returns the current status of the query.
+ *
+ * @since 2.0.2
+ */
+ def status: StreamingQueryStatus
+
+ /**
+ * Returns an array of the most recent [[org.apache.spark.sql.streaming.StreamingQueryProgress]]
+ * updates for this query. The number of progress updates retained for each stream is configured
+ * by Spark session configuration `spark.sql.streaming.numRecentProgressUpdates`.
+ *
+ * @since 2.1.0
+ */
+ def recentProgress: Array[StreamingQueryProgress]
+
+ /**
+ * Returns the most recent [[org.apache.spark.sql.streaming.StreamingQueryProgress]] update of
+ * this streaming query.
+ *
+ * @since 2.1.0
+ */
+ def lastProgress: StreamingQueryProgress
+
+ /**
+ * Waits for the termination of `this` query, either by `query.stop()` or by an exception. If
+ * the query has terminated with an exception, then the exception will be thrown.
+ *
+ * If the query has terminated, then all subsequent calls to this method will either return
+ * immediately (if the query was terminated by `stop()`), or throw the exception immediately (if
+ * the query has terminated with exception).
+ *
+ * @throws org.apache.spark.sql.streaming.StreamingQueryException
+ * if the query has terminated with an exception.
+ *
+ * @since 2.0.0
+ */
+ @throws[StreamingQueryException]
+ def awaitTermination(): Unit
+
+ /**
+ * Waits for the termination of `this` query, either by `query.stop()` or by an exception. If
+ * the query has terminated with an exception, then the exception will be thrown. Otherwise, it
+ * returns whether the query has terminated or not within the `timeoutMs` milliseconds.
+ *
+ * If the query has terminated, then all subsequent calls to this method will either return
+ * `true` immediately (if the query was terminated by `stop()`), or throw the exception
+ * immediately (if the query has terminated with exception).
+ *
+ * @throws org.apache.spark.sql.streaming.StreamingQueryException
+ * if the query has terminated with an exception
+ *
+ * @since 2.0.0
+ */
+ @throws[StreamingQueryException]
+ def awaitTermination(timeoutMs: Long): Boolean
+
+ /**
+ * Blocks until all available data in the source has been processed and committed to the sink.
+ * This method is intended for testing. Note that in the case of continually arriving data, this
+ * method may block forever. Additionally, this method is only guaranteed to block until data
+ * that has been synchronously appended to a
+ * `org.apache.spark.sql.execution.streaming.Source` prior to invocation. (i.e. `getOffset` must
+ * immediately reflect the addition).
+ * @since 2.0.0
+ */
+ def processAllAvailable(): Unit
+
+ /**
+ * Stops the execution of this query if it is running. This waits until the termination of the
+ * query execution threads or until a timeout is hit.
+ *
+ * By default stop will block indefinitely. You can configure a timeout by the configuration
+ * `spark.sql.streaming.stopTimeout`. A timeout of 0 (or negative) milliseconds will block
+ * indefinitely. If a `TimeoutException` is thrown, users can retry stopping the stream. If the
+ * issue persists, it is advisable to kill the Spark application.
+ *
+ * @since 2.0.0
+ */
+ @throws[TimeoutException]
+ def stop(): Unit
+
+ /**
+ * Prints the physical plan to the console for debugging purposes.
+ * @since 2.0.0
+ */
+ def explain(): Unit
+
+ /**
+ * Prints the physical plan to the console for debugging purposes.
+ *
+ * @param extended
+ * whether to do extended explain or not
+ * @since 2.0.0
+ */
+ def explain(extended: Boolean): Unit
+}
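To illustrate the handle above, a hedged sketch that starts a query against the built-in `rate` source and drives it through the lifecycle methods; the sink choice and query name are illustrative, and an active `spark` session is assumed.

```scala
val query = spark.readStream
  .format("rate")          // built-in test source that emits (timestamp, value) rows
  .load()
  .writeStream
  .format("memory")
  .queryName("rate_rows")  // the memory sink requires a query name
  .start()

query.processAllAvailable()                 // block until pending data is committed (testing only)
println(query.status)                       // current StreamingQueryStatus
Option(query.lastProgress).foreach(println) // most recent progress, if any
query.stop()                                // bounded by spark.sql.streaming.stopTimeout
```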
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/UDFRegistration.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/UDFRegistration.scala
index 4611393f0f7ec..c11e266827ff9 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/api/UDFRegistration.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/api/UDFRegistration.scala
@@ -50,9 +50,12 @@ abstract class UDFRegistration {
* spark.udf.register("stringLit", bar.asNonNullable())
* }}}
*
- * @param name the name of the UDF.
- * @param udf the UDF needs to be registered.
- * @return the registered UDF.
+ * @param name
+ * the name of the UDF.
+ * @param udf
+ * the UDF needs to be registered.
+ * @return
+ * the registered UDF.
*
* @since 2.2.0
*/
@@ -117,11 +120,12 @@ abstract class UDFRegistration {
| registerJavaUDF(name, ToScalaUDF(f), returnType, $i)
|}""".stripMargin)
}
- */
+ */
/**
* Registers a deterministic Scala closure of 0 arguments as user-defined function (UDF).
- * @tparam RT return type of UDF.
+ * @tparam RT
+ * return type of UDF.
* @since 1.3.0
*/
def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = {
@@ -130,200 +134,964 @@ abstract class UDFRegistration {
/**
* Registers a deterministic Scala closure of 1 arguments as user-defined function (UDF).
- * @tparam RT return type of UDF.
+ * @tparam RT
+ * return type of UDF.
* @since 1.3.0
*/
- def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = {
+ def register[RT: TypeTag, A1: TypeTag](
+ name: String,
+ func: Function1[A1, RT]): UserDefinedFunction = {
registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]])
}
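A minimal sketch of the closure-based `register` overloads above, assuming an active `spark` session with `spark.implicits._` in scope; the UDF name is illustrative.

```scala
import spark.implicits._

// One-argument Scala closure registered as a SQL-callable UDF.
spark.udf.register("shout", (s: String) => s.toUpperCase + "!")

Seq("hello", "world").toDF("word").selectExpr("shout(word)").show()
```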
/**
* Registers a deterministic Scala closure of 2 arguments as user-defined function (UDF).
- * @tparam RT return type of UDF.
+ * @tparam RT
+ * return type of UDF.
* @since 1.3.0
*/
- def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = {
- registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]])
+ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](
+ name: String,
+ func: Function2[A1, A2, RT]): UserDefinedFunction = {
+ registerScalaUDF(
+ name,
+ func,
+ implicitly[TypeTag[RT]],
+ implicitly[TypeTag[A1]],
+ implicitly[TypeTag[A2]])
}
/**
* Registers a deterministic Scala closure of 3 arguments as user-defined function (UDF).
- * @tparam RT return type of UDF.
+ * @tparam RT
+ * return type of UDF.
* @since 1.3.0
*/
- def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = {
- registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]])
+ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](
+ name: String,
+ func: Function3[A1, A2, A3, RT]): UserDefinedFunction = {
+ registerScalaUDF(
+ name,
+ func,
+ implicitly[TypeTag[RT]],
+ implicitly[TypeTag[A1]],
+ implicitly[TypeTag[A2]],
+ implicitly[TypeTag[A3]])
}
/**
* Registers a deterministic Scala closure of 4 arguments as user-defined function (UDF).
- * @tparam RT return type of UDF.
+ * @tparam RT
+ * return type of UDF.
* @since 1.3.0
*/
- def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = {
- registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]])
+ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](
+ name: String,
+ func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = {
+ registerScalaUDF(
+ name,
+ func,
+ implicitly[TypeTag[RT]],
+ implicitly[TypeTag[A1]],
+ implicitly[TypeTag[A2]],
+ implicitly[TypeTag[A3]],
+ implicitly[TypeTag[A4]])
}
/**
* Registers a deterministic Scala closure of 5 arguments as user-defined function (UDF).
- * @tparam RT return type of UDF.
+ * @tparam RT
+ * return type of UDF.
* @since 1.3.0
*/
- def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = {
- registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]])
+ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](
+ name: String,
+ func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = {
+ registerScalaUDF(
+ name,
+ func,
+ implicitly[TypeTag[RT]],
+ implicitly[TypeTag[A1]],
+ implicitly[TypeTag[A2]],
+ implicitly[TypeTag[A3]],
+ implicitly[TypeTag[A4]],
+ implicitly[TypeTag[A5]])
}
/**
* Registers a deterministic Scala closure of 6 arguments as user-defined function (UDF).
- * @tparam RT return type of UDF.
+ * @tparam RT
+ * return type of UDF.
* @since 1.3.0
*/
- def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = {
- registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]])
+ def register[
+ RT: TypeTag,
+ A1: TypeTag,
+ A2: TypeTag,
+ A3: TypeTag,
+ A4: TypeTag,
+ A5: TypeTag,
+ A6: TypeTag](
+ name: String,
+ func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = {
+ registerScalaUDF(
+ name,
+ func,
+ implicitly[TypeTag[RT]],
+ implicitly[TypeTag[A1]],
+ implicitly[TypeTag[A2]],
+ implicitly[TypeTag[A3]],
+ implicitly[TypeTag[A4]],
+ implicitly[TypeTag[A5]],
+ implicitly[TypeTag[A6]])
}
/**
* Registers a deterministic Scala closure of 7 arguments as user-defined function (UDF).
- * @tparam RT return type of UDF.
+ * @tparam RT
+ * return type of UDF.
* @since 1.3.0
*/
- def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = {
- registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]])
+ def register[
+ RT: TypeTag,
+ A1: TypeTag,
+ A2: TypeTag,
+ A3: TypeTag,
+ A4: TypeTag,
+ A5: TypeTag,
+ A6: TypeTag,
+ A7: TypeTag](
+ name: String,
+ func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = {
+ registerScalaUDF(
+ name,
+ func,
+ implicitly[TypeTag[RT]],
+ implicitly[TypeTag[A1]],
+ implicitly[TypeTag[A2]],
+ implicitly[TypeTag[A3]],
+ implicitly[TypeTag[A4]],
+ implicitly[TypeTag[A5]],
+ implicitly[TypeTag[A6]],
+ implicitly[TypeTag[A7]])
}
/**
* Registers a deterministic Scala closure of 8 arguments as user-defined function (UDF).
- * @tparam RT return type of UDF.
+ * @tparam RT
+ * return type of UDF.
* @since 1.3.0
*/
- def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = {
- registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]])
+ def register[
+ RT: TypeTag,
+ A1: TypeTag,
+ A2: TypeTag,
+ A3: TypeTag,
+ A4: TypeTag,
+ A5: TypeTag,
+ A6: TypeTag,
+ A7: TypeTag,
+ A8: TypeTag](
+ name: String,
+ func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = {
+ registerScalaUDF(
+ name,
+ func,
+ implicitly[TypeTag[RT]],
+ implicitly[TypeTag[A1]],
+ implicitly[TypeTag[A2]],
+ implicitly[TypeTag[A3]],
+ implicitly[TypeTag[A4]],
+ implicitly[TypeTag[A5]],
+ implicitly[TypeTag[A6]],
+ implicitly[TypeTag[A7]],
+ implicitly[TypeTag[A8]])
}
/**
* Registers a deterministic Scala closure of 9 arguments as user-defined function (UDF).
- * @tparam RT return type of UDF.
+ * @tparam RT
+ * return type of UDF.
* @since 1.3.0
*/
- def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = {
- registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]])
+ def register[
+ RT: TypeTag,
+ A1: TypeTag,
+ A2: TypeTag,
+ A3: TypeTag,
+ A4: TypeTag,
+ A5: TypeTag,
+ A6: TypeTag,
+ A7: TypeTag,
+ A8: TypeTag,
+ A9: TypeTag](
+ name: String,
+ func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = {
+ registerScalaUDF(
+ name,
+ func,
+ implicitly[TypeTag[RT]],
+ implicitly[TypeTag[A1]],
+ implicitly[TypeTag[A2]],
+ implicitly[TypeTag[A3]],
+ implicitly[TypeTag[A4]],
+ implicitly[TypeTag[A5]],
+ implicitly[TypeTag[A6]],
+ implicitly[TypeTag[A7]],
+ implicitly[TypeTag[A8]],
+ implicitly[TypeTag[A9]])
}
/**
* Registers a deterministic Scala closure of 10 arguments as user-defined function (UDF).
- * @tparam RT return type of UDF.
+ * @tparam RT
+ * return type of UDF.
* @since 1.3.0
*/
- def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = {
- registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]])
+ def register[
+ RT: TypeTag,
+ A1: TypeTag,
+ A2: TypeTag,
+ A3: TypeTag,
+ A4: TypeTag,
+ A5: TypeTag,
+ A6: TypeTag,
+ A7: TypeTag,
+ A8: TypeTag,
+ A9: TypeTag,
+ A10: TypeTag](
+ name: String,
+ func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = {
+ registerScalaUDF(
+ name,
+ func,
+ implicitly[TypeTag[RT]],
+ implicitly[TypeTag[A1]],
+ implicitly[TypeTag[A2]],
+ implicitly[TypeTag[A3]],
+ implicitly[TypeTag[A4]],
+ implicitly[TypeTag[A5]],
+ implicitly[TypeTag[A6]],
+ implicitly[TypeTag[A7]],
+ implicitly[TypeTag[A8]],
+ implicitly[TypeTag[A9]],
+ implicitly[TypeTag[A10]])
}
/**
* Registers a deterministic Scala closure of 11 arguments as user-defined function (UDF).
- * @tparam RT return type of UDF.
+ * @tparam RT
+ * return type of UDF.
* @since 1.3.0
*/
- def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = {
- registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]])
+ def register[
+ RT: TypeTag,
+ A1: TypeTag,
+ A2: TypeTag,
+ A3: TypeTag,
+ A4: TypeTag,
+ A5: TypeTag,
+ A6: TypeTag,
+ A7: TypeTag,
+ A8: TypeTag,
+ A9: TypeTag,
+ A10: TypeTag,
+ A11: TypeTag](
+ name: String,
+ func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = {
+ registerScalaUDF(
+ name,
+ func,
+ implicitly[TypeTag[RT]],
+ implicitly[TypeTag[A1]],
+ implicitly[TypeTag[A2]],
+ implicitly[TypeTag[A3]],
+ implicitly[TypeTag[A4]],
+ implicitly[TypeTag[A5]],
+ implicitly[TypeTag[A6]],
+ implicitly[TypeTag[A7]],
+ implicitly[TypeTag[A8]],
+ implicitly[TypeTag[A9]],
+ implicitly[TypeTag[A10]],
+ implicitly[TypeTag[A11]])
}
/**
* Registers a deterministic Scala closure of 12 arguments as user-defined function (UDF).
- * @tparam RT return type of UDF.
+ * @tparam RT
+ * return type of UDF.
* @since 1.3.0
*/
- def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = {
- registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]])
+ def register[
+ RT: TypeTag,
+ A1: TypeTag,
+ A2: TypeTag,
+ A3: TypeTag,
+ A4: TypeTag,
+ A5: TypeTag,
+ A6: TypeTag,
+ A7: TypeTag,
+ A8: TypeTag,
+ A9: TypeTag,
+ A10: TypeTag,
+ A11: TypeTag,
+ A12: TypeTag](
+ name: String,
+ func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT])
+ : UserDefinedFunction = {
+ registerScalaUDF(
+ name,
+ func,
+ implicitly[TypeTag[RT]],
+ implicitly[TypeTag[A1]],
+ implicitly[TypeTag[A2]],
+ implicitly[TypeTag[A3]],
+ implicitly[TypeTag[A4]],
+ implicitly[TypeTag[A5]],
+ implicitly[TypeTag[A6]],
+ implicitly[TypeTag[A7]],
+ implicitly[TypeTag[A8]],
+ implicitly[TypeTag[A9]],
+ implicitly[TypeTag[A10]],
+ implicitly[TypeTag[A11]],
+ implicitly[TypeTag[A12]])
}
/**
* Registers a deterministic Scala closure of 13 arguments as user-defined function (UDF).
- * @tparam RT return type of UDF.
+ * @tparam RT
+ * return type of UDF.
* @since 1.3.0
*/
- def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = {
- registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]])
+ def register[
+ RT: TypeTag,
+ A1: TypeTag,
+ A2: TypeTag,
+ A3: TypeTag,
+ A4: TypeTag,
+ A5: TypeTag,
+ A6: TypeTag,
+ A7: TypeTag,
+ A8: TypeTag,
+ A9: TypeTag,
+ A10: TypeTag,
+ A11: TypeTag,
+ A12: TypeTag,
+ A13: TypeTag](
+ name: String,
+ func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT])
+ : UserDefinedFunction = {
+ registerScalaUDF(
+ name,
+ func,
+ implicitly[TypeTag[RT]],
+ implicitly[TypeTag[A1]],
+ implicitly[TypeTag[A2]],
+ implicitly[TypeTag[A3]],
+ implicitly[TypeTag[A4]],
+ implicitly[TypeTag[A5]],
+ implicitly[TypeTag[A6]],
+ implicitly[TypeTag[A7]],
+ implicitly[TypeTag[A8]],
+ implicitly[TypeTag[A9]],
+ implicitly[TypeTag[A10]],
+ implicitly[TypeTag[A11]],
+ implicitly[TypeTag[A12]],
+ implicitly[TypeTag[A13]])
}
/**
* Registers a deterministic Scala closure of 14 arguments as user-defined function (UDF).
- * @tparam RT return type of UDF.
+ * @tparam RT
+ * return type of UDF.
* @since 1.3.0
*/
- def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = {
- registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]])
+ def register[
+ RT: TypeTag,
+ A1: TypeTag,
+ A2: TypeTag,
+ A3: TypeTag,
+ A4: TypeTag,
+ A5: TypeTag,
+ A6: TypeTag,
+ A7: TypeTag,
+ A8: TypeTag,
+ A9: TypeTag,
+ A10: TypeTag,
+ A11: TypeTag,
+ A12: TypeTag,
+ A13: TypeTag,
+ A14: TypeTag](
+ name: String,
+ func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT])
+ : UserDefinedFunction = {
+ registerScalaUDF(
+ name,
+ func,
+ implicitly[TypeTag[RT]],
+ implicitly[TypeTag[A1]],
+ implicitly[TypeTag[A2]],
+ implicitly[TypeTag[A3]],
+ implicitly[TypeTag[A4]],
+ implicitly[TypeTag[A5]],
+ implicitly[TypeTag[A6]],
+ implicitly[TypeTag[A7]],
+ implicitly[TypeTag[A8]],
+ implicitly[TypeTag[A9]],
+ implicitly[TypeTag[A10]],
+ implicitly[TypeTag[A11]],
+ implicitly[TypeTag[A12]],
+ implicitly[TypeTag[A13]],
+ implicitly[TypeTag[A14]])
}
/**
* Registers a deterministic Scala closure of 15 arguments as user-defined function (UDF).
- * @tparam RT return type of UDF.
+ * @tparam RT
+ * return type of UDF.
* @since 1.3.0
*/
- def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = {
- registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]])
+ def register[
+ RT: TypeTag,
+ A1: TypeTag,
+ A2: TypeTag,
+ A3: TypeTag,
+ A4: TypeTag,
+ A5: TypeTag,
+ A6: TypeTag,
+ A7: TypeTag,
+ A8: TypeTag,
+ A9: TypeTag,
+ A10: TypeTag,
+ A11: TypeTag,
+ A12: TypeTag,
+ A13: TypeTag,
+ A14: TypeTag,
+ A15: TypeTag](
+ name: String,
+ func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT])
+ : UserDefinedFunction = {
+ registerScalaUDF(
+ name,
+ func,
+ implicitly[TypeTag[RT]],
+ implicitly[TypeTag[A1]],
+ implicitly[TypeTag[A2]],
+ implicitly[TypeTag[A3]],
+ implicitly[TypeTag[A4]],
+ implicitly[TypeTag[A5]],
+ implicitly[TypeTag[A6]],
+ implicitly[TypeTag[A7]],
+ implicitly[TypeTag[A8]],
+ implicitly[TypeTag[A9]],
+ implicitly[TypeTag[A10]],
+ implicitly[TypeTag[A11]],
+ implicitly[TypeTag[A12]],
+ implicitly[TypeTag[A13]],
+ implicitly[TypeTag[A14]],
+ implicitly[TypeTag[A15]])
}
/**
* Registers a deterministic Scala closure of 16 arguments as user-defined function (UDF).
- * @tparam RT return type of UDF.
+ * @tparam RT
+ * return type of UDF.
* @since 1.3.0
*/
- def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = {
- registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]])
+ def register[
+ RT: TypeTag,
+ A1: TypeTag,
+ A2: TypeTag,
+ A3: TypeTag,
+ A4: TypeTag,
+ A5: TypeTag,
+ A6: TypeTag,
+ A7: TypeTag,
+ A8: TypeTag,
+ A9: TypeTag,
+ A10: TypeTag,
+ A11: TypeTag,
+ A12: TypeTag,
+ A13: TypeTag,
+ A14: TypeTag,
+ A15: TypeTag,
+ A16: TypeTag](
+ name: String,
+ func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT])
+ : UserDefinedFunction = {
+ registerScalaUDF(
+ name,
+ func,
+ implicitly[TypeTag[RT]],
+ implicitly[TypeTag[A1]],
+ implicitly[TypeTag[A2]],
+ implicitly[TypeTag[A3]],
+ implicitly[TypeTag[A4]],
+ implicitly[TypeTag[A5]],
+ implicitly[TypeTag[A6]],
+ implicitly[TypeTag[A7]],
+ implicitly[TypeTag[A8]],
+ implicitly[TypeTag[A9]],
+ implicitly[TypeTag[A10]],
+ implicitly[TypeTag[A11]],
+ implicitly[TypeTag[A12]],
+ implicitly[TypeTag[A13]],
+ implicitly[TypeTag[A14]],
+ implicitly[TypeTag[A15]],
+ implicitly[TypeTag[A16]])
}
/**
* Registers a deterministic Scala closure of 17 arguments as user-defined function (UDF).
- * @tparam RT return type of UDF.
+ * @tparam RT
+ * return type of UDF.
* @since 1.3.0
*/
- def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = {
- registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]], implicitly[TypeTag[A17]])
+ def register[
+ RT: TypeTag,
+ A1: TypeTag,
+ A2: TypeTag,
+ A3: TypeTag,
+ A4: TypeTag,
+ A5: TypeTag,
+ A6: TypeTag,
+ A7: TypeTag,
+ A8: TypeTag,
+ A9: TypeTag,
+ A10: TypeTag,
+ A11: TypeTag,
+ A12: TypeTag,
+ A13: TypeTag,
+ A14: TypeTag,
+ A15: TypeTag,
+ A16: TypeTag,
+ A17: TypeTag](
+ name: String,
+ func: Function17[
+ A1,
+ A2,
+ A3,
+ A4,
+ A5,
+ A6,
+ A7,
+ A8,
+ A9,
+ A10,
+ A11,
+ A12,
+ A13,
+ A14,
+ A15,
+ A16,
+ A17,
+ RT]): UserDefinedFunction = {
+ registerScalaUDF(
+ name,
+ func,
+ implicitly[TypeTag[RT]],
+ implicitly[TypeTag[A1]],
+ implicitly[TypeTag[A2]],
+ implicitly[TypeTag[A3]],
+ implicitly[TypeTag[A4]],
+ implicitly[TypeTag[A5]],
+ implicitly[TypeTag[A6]],
+ implicitly[TypeTag[A7]],
+ implicitly[TypeTag[A8]],
+ implicitly[TypeTag[A9]],
+ implicitly[TypeTag[A10]],
+ implicitly[TypeTag[A11]],
+ implicitly[TypeTag[A12]],
+ implicitly[TypeTag[A13]],
+ implicitly[TypeTag[A14]],
+ implicitly[TypeTag[A15]],
+ implicitly[TypeTag[A16]],
+ implicitly[TypeTag[A17]])
}
/**
* Registers a deterministic Scala closure of 18 arguments as user-defined function (UDF).
- * @tparam RT return type of UDF.
+ * @tparam RT
+ * return type of UDF.
* @since 1.3.0
*/
- def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = {
- registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]], implicitly[TypeTag[A17]], implicitly[TypeTag[A18]])
+ def register[
+ RT: TypeTag,
+ A1: TypeTag,
+ A2: TypeTag,
+ A3: TypeTag,
+ A4: TypeTag,
+ A5: TypeTag,
+ A6: TypeTag,
+ A7: TypeTag,
+ A8: TypeTag,
+ A9: TypeTag,
+ A10: TypeTag,
+ A11: TypeTag,
+ A12: TypeTag,
+ A13: TypeTag,
+ A14: TypeTag,
+ A15: TypeTag,
+ A16: TypeTag,
+ A17: TypeTag,
+ A18: TypeTag](
+ name: String,
+ func: Function18[
+ A1,
+ A2,
+ A3,
+ A4,
+ A5,
+ A6,
+ A7,
+ A8,
+ A9,
+ A10,
+ A11,
+ A12,
+ A13,
+ A14,
+ A15,
+ A16,
+ A17,
+ A18,
+ RT]): UserDefinedFunction = {
+ registerScalaUDF(
+ name,
+ func,
+ implicitly[TypeTag[RT]],
+ implicitly[TypeTag[A1]],
+ implicitly[TypeTag[A2]],
+ implicitly[TypeTag[A3]],
+ implicitly[TypeTag[A4]],
+ implicitly[TypeTag[A5]],
+ implicitly[TypeTag[A6]],
+ implicitly[TypeTag[A7]],
+ implicitly[TypeTag[A8]],
+ implicitly[TypeTag[A9]],
+ implicitly[TypeTag[A10]],
+ implicitly[TypeTag[A11]],
+ implicitly[TypeTag[A12]],
+ implicitly[TypeTag[A13]],
+ implicitly[TypeTag[A14]],
+ implicitly[TypeTag[A15]],
+ implicitly[TypeTag[A16]],
+ implicitly[TypeTag[A17]],
+ implicitly[TypeTag[A18]])
}
/**
* Registers a deterministic Scala closure of 19 arguments as user-defined function (UDF).
- * @tparam RT return type of UDF.
+ * @tparam RT
+ * return type of UDF.
* @since 1.3.0
*/
- def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = {
- registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]], implicitly[TypeTag[A17]], implicitly[TypeTag[A18]], implicitly[TypeTag[A19]])
+ def register[
+ RT: TypeTag,
+ A1: TypeTag,
+ A2: TypeTag,
+ A3: TypeTag,
+ A4: TypeTag,
+ A5: TypeTag,
+ A6: TypeTag,
+ A7: TypeTag,
+ A8: TypeTag,
+ A9: TypeTag,
+ A10: TypeTag,
+ A11: TypeTag,
+ A12: TypeTag,
+ A13: TypeTag,
+ A14: TypeTag,
+ A15: TypeTag,
+ A16: TypeTag,
+ A17: TypeTag,
+ A18: TypeTag,
+ A19: TypeTag](
+ name: String,
+ func: Function19[
+ A1,
+ A2,
+ A3,
+ A4,
+ A5,
+ A6,
+ A7,
+ A8,
+ A9,
+ A10,
+ A11,
+ A12,
+ A13,
+ A14,
+ A15,
+ A16,
+ A17,
+ A18,
+ A19,
+ RT]): UserDefinedFunction = {
+ registerScalaUDF(
+ name,
+ func,
+ implicitly[TypeTag[RT]],
+ implicitly[TypeTag[A1]],
+ implicitly[TypeTag[A2]],
+ implicitly[TypeTag[A3]],
+ implicitly[TypeTag[A4]],
+ implicitly[TypeTag[A5]],
+ implicitly[TypeTag[A6]],
+ implicitly[TypeTag[A7]],
+ implicitly[TypeTag[A8]],
+ implicitly[TypeTag[A9]],
+ implicitly[TypeTag[A10]],
+ implicitly[TypeTag[A11]],
+ implicitly[TypeTag[A12]],
+ implicitly[TypeTag[A13]],
+ implicitly[TypeTag[A14]],
+ implicitly[TypeTag[A15]],
+ implicitly[TypeTag[A16]],
+ implicitly[TypeTag[A17]],
+ implicitly[TypeTag[A18]],
+ implicitly[TypeTag[A19]])
}
/**
* Registers a deterministic Scala closure of 20 arguments as user-defined function (UDF).
- * @tparam RT return type of UDF.
+ * @tparam RT
+ * return type of UDF.
* @since 1.3.0
*/
- def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = {
- registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]], implicitly[TypeTag[A17]], implicitly[TypeTag[A18]], implicitly[TypeTag[A19]], implicitly[TypeTag[A20]])
+ def register[
+ RT: TypeTag,
+ A1: TypeTag,
+ A2: TypeTag,
+ A3: TypeTag,
+ A4: TypeTag,
+ A5: TypeTag,
+ A6: TypeTag,
+ A7: TypeTag,
+ A8: TypeTag,
+ A9: TypeTag,
+ A10: TypeTag,
+ A11: TypeTag,
+ A12: TypeTag,
+ A13: TypeTag,
+ A14: TypeTag,
+ A15: TypeTag,
+ A16: TypeTag,
+ A17: TypeTag,
+ A18: TypeTag,
+ A19: TypeTag,
+ A20: TypeTag](
+ name: String,
+ func: Function20[
+ A1,
+ A2,
+ A3,
+ A4,
+ A5,
+ A6,
+ A7,
+ A8,
+ A9,
+ A10,
+ A11,
+ A12,
+ A13,
+ A14,
+ A15,
+ A16,
+ A17,
+ A18,
+ A19,
+ A20,
+ RT]): UserDefinedFunction = {
+ registerScalaUDF(
+ name,
+ func,
+ implicitly[TypeTag[RT]],
+ implicitly[TypeTag[A1]],
+ implicitly[TypeTag[A2]],
+ implicitly[TypeTag[A3]],
+ implicitly[TypeTag[A4]],
+ implicitly[TypeTag[A5]],
+ implicitly[TypeTag[A6]],
+ implicitly[TypeTag[A7]],
+ implicitly[TypeTag[A8]],
+ implicitly[TypeTag[A9]],
+ implicitly[TypeTag[A10]],
+ implicitly[TypeTag[A11]],
+ implicitly[TypeTag[A12]],
+ implicitly[TypeTag[A13]],
+ implicitly[TypeTag[A14]],
+ implicitly[TypeTag[A15]],
+ implicitly[TypeTag[A16]],
+ implicitly[TypeTag[A17]],
+ implicitly[TypeTag[A18]],
+ implicitly[TypeTag[A19]],
+ implicitly[TypeTag[A20]])
}
/**
* Registers a deterministic Scala closure of 21 arguments as user-defined function (UDF).
- * @tparam RT return type of UDF.
+ * @tparam RT
+ * return type of UDF.
* @since 1.3.0
*/
- def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = {
- registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]], implicitly[TypeTag[A17]], implicitly[TypeTag[A18]], implicitly[TypeTag[A19]], implicitly[TypeTag[A20]], implicitly[TypeTag[A21]])
+ def register[
+ RT: TypeTag,
+ A1: TypeTag,
+ A2: TypeTag,
+ A3: TypeTag,
+ A4: TypeTag,
+ A5: TypeTag,
+ A6: TypeTag,
+ A7: TypeTag,
+ A8: TypeTag,
+ A9: TypeTag,
+ A10: TypeTag,
+ A11: TypeTag,
+ A12: TypeTag,
+ A13: TypeTag,
+ A14: TypeTag,
+ A15: TypeTag,
+ A16: TypeTag,
+ A17: TypeTag,
+ A18: TypeTag,
+ A19: TypeTag,
+ A20: TypeTag,
+ A21: TypeTag](
+ name: String,
+ func: Function21[
+ A1,
+ A2,
+ A3,
+ A4,
+ A5,
+ A6,
+ A7,
+ A8,
+ A9,
+ A10,
+ A11,
+ A12,
+ A13,
+ A14,
+ A15,
+ A16,
+ A17,
+ A18,
+ A19,
+ A20,
+ A21,
+ RT]): UserDefinedFunction = {
+ registerScalaUDF(
+ name,
+ func,
+ implicitly[TypeTag[RT]],
+ implicitly[TypeTag[A1]],
+ implicitly[TypeTag[A2]],
+ implicitly[TypeTag[A3]],
+ implicitly[TypeTag[A4]],
+ implicitly[TypeTag[A5]],
+ implicitly[TypeTag[A6]],
+ implicitly[TypeTag[A7]],
+ implicitly[TypeTag[A8]],
+ implicitly[TypeTag[A9]],
+ implicitly[TypeTag[A10]],
+ implicitly[TypeTag[A11]],
+ implicitly[TypeTag[A12]],
+ implicitly[TypeTag[A13]],
+ implicitly[TypeTag[A14]],
+ implicitly[TypeTag[A15]],
+ implicitly[TypeTag[A16]],
+ implicitly[TypeTag[A17]],
+ implicitly[TypeTag[A18]],
+ implicitly[TypeTag[A19]],
+ implicitly[TypeTag[A20]],
+ implicitly[TypeTag[A21]])
}
/**
* Registers a deterministic Scala closure of 22 arguments as user-defined function (UDF).
- * @tparam RT return type of UDF.
+ * @tparam RT
+ * return type of UDF.
* @since 1.3.0
*/
- def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = {
- registerScalaUDF(name, func, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]], implicitly[TypeTag[A11]], implicitly[TypeTag[A12]], implicitly[TypeTag[A13]], implicitly[TypeTag[A14]], implicitly[TypeTag[A15]], implicitly[TypeTag[A16]], implicitly[TypeTag[A17]], implicitly[TypeTag[A18]], implicitly[TypeTag[A19]], implicitly[TypeTag[A20]], implicitly[TypeTag[A21]], implicitly[TypeTag[A22]])
+ def register[
+ RT: TypeTag,
+ A1: TypeTag,
+ A2: TypeTag,
+ A3: TypeTag,
+ A4: TypeTag,
+ A5: TypeTag,
+ A6: TypeTag,
+ A7: TypeTag,
+ A8: TypeTag,
+ A9: TypeTag,
+ A10: TypeTag,
+ A11: TypeTag,
+ A12: TypeTag,
+ A13: TypeTag,
+ A14: TypeTag,
+ A15: TypeTag,
+ A16: TypeTag,
+ A17: TypeTag,
+ A18: TypeTag,
+ A19: TypeTag,
+ A20: TypeTag,
+ A21: TypeTag,
+ A22: TypeTag](
+ name: String,
+ func: Function22[
+ A1,
+ A2,
+ A3,
+ A4,
+ A5,
+ A6,
+ A7,
+ A8,
+ A9,
+ A10,
+ A11,
+ A12,
+ A13,
+ A14,
+ A15,
+ A16,
+ A17,
+ A18,
+ A19,
+ A20,
+ A21,
+ A22,
+ RT]): UserDefinedFunction = {
+ registerScalaUDF(
+ name,
+ func,
+ implicitly[TypeTag[RT]],
+ implicitly[TypeTag[A1]],
+ implicitly[TypeTag[A2]],
+ implicitly[TypeTag[A3]],
+ implicitly[TypeTag[A4]],
+ implicitly[TypeTag[A5]],
+ implicitly[TypeTag[A6]],
+ implicitly[TypeTag[A7]],
+ implicitly[TypeTag[A8]],
+ implicitly[TypeTag[A9]],
+ implicitly[TypeTag[A10]],
+ implicitly[TypeTag[A11]],
+ implicitly[TypeTag[A12]],
+ implicitly[TypeTag[A13]],
+ implicitly[TypeTag[A14]],
+ implicitly[TypeTag[A15]],
+ implicitly[TypeTag[A16]],
+ implicitly[TypeTag[A17]],
+ implicitly[TypeTag[A18]],
+ implicitly[TypeTag[A19]],
+ implicitly[TypeTag[A20]],
+ implicitly[TypeTag[A21]],
+ implicitly[TypeTag[A22]])
}
//////////////////////////////////////////////////////////////////////////////////////////////
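For context (not part of the diff above): these typed overloads back the public `spark.udf.register` entry point for Scala closures, with the `TypeTag` evidence captured implicitly at the call site. A minimal usage sketch, assuming a local SparkSession:

  import org.apache.spark.sql.SparkSession

  val spark = SparkSession.builder().master("local[*]").appName("udf-sketch").getOrCreate()

  // Registers a deterministic 4-argument Scala closure; this resolves to the Function4 overload.
  spark.udf.register("weightedSum", (a: Int, b: Int, c: Int, d: Int) => a + 2 * b + 3 * c + 4 * d)

  spark.sql("SELECT weightedSum(1, 2, 3, 4)").show()  // single row with value 30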
@@ -405,7 +1173,10 @@ abstract class UDFRegistration {
* Register a deterministic Java UDF9 instance as user-defined function (UDF).
* @since 1.3.0
*/
- def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+ def register(
+ name: String,
+ f: UDF9[_, _, _, _, _, _, _, _, _, _],
+ returnType: DataType): Unit = {
registerJavaUDF(name, ToScalaUDF(f), returnType, 9)
}
@@ -413,7 +1184,10 @@ abstract class UDFRegistration {
* Register a deterministic Java UDF10 instance as user-defined function (UDF).
* @since 1.3.0
*/
- def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+ def register(
+ name: String,
+ f: UDF10[_, _, _, _, _, _, _, _, _, _, _],
+ returnType: DataType): Unit = {
registerJavaUDF(name, ToScalaUDF(f), returnType, 10)
}
@@ -421,7 +1195,10 @@ abstract class UDFRegistration {
* Register a deterministic Java UDF11 instance as user-defined function (UDF).
* @since 1.3.0
*/
- def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+ def register(
+ name: String,
+ f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _],
+ returnType: DataType): Unit = {
registerJavaUDF(name, ToScalaUDF(f), returnType, 11)
}
@@ -429,7 +1206,10 @@ abstract class UDFRegistration {
* Register a deterministic Java UDF12 instance as user-defined function (UDF).
* @since 1.3.0
*/
- def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+ def register(
+ name: String,
+ f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _],
+ returnType: DataType): Unit = {
registerJavaUDF(name, ToScalaUDF(f), returnType, 12)
}
@@ -437,7 +1217,10 @@ abstract class UDFRegistration {
* Register a deterministic Java UDF13 instance as user-defined function (UDF).
* @since 1.3.0
*/
- def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+ def register(
+ name: String,
+ f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _],
+ returnType: DataType): Unit = {
registerJavaUDF(name, ToScalaUDF(f), returnType, 13)
}
@@ -445,7 +1228,10 @@ abstract class UDFRegistration {
* Register a deterministic Java UDF14 instance as user-defined function (UDF).
* @since 1.3.0
*/
- def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+ def register(
+ name: String,
+ f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _],
+ returnType: DataType): Unit = {
registerJavaUDF(name, ToScalaUDF(f), returnType, 14)
}
@@ -453,7 +1239,10 @@ abstract class UDFRegistration {
* Register a deterministic Java UDF15 instance as user-defined function (UDF).
* @since 1.3.0
*/
- def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+ def register(
+ name: String,
+ f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _],
+ returnType: DataType): Unit = {
registerJavaUDF(name, ToScalaUDF(f), returnType, 15)
}
@@ -461,7 +1250,10 @@ abstract class UDFRegistration {
* Register a deterministic Java UDF16 instance as user-defined function (UDF).
* @since 1.3.0
*/
- def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+ def register(
+ name: String,
+ f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _],
+ returnType: DataType): Unit = {
registerJavaUDF(name, ToScalaUDF(f), returnType, 16)
}
@@ -469,7 +1261,10 @@ abstract class UDFRegistration {
* Register a deterministic Java UDF17 instance as user-defined function (UDF).
* @since 1.3.0
*/
- def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+ def register(
+ name: String,
+ f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _],
+ returnType: DataType): Unit = {
registerJavaUDF(name, ToScalaUDF(f), returnType, 17)
}
@@ -477,7 +1272,10 @@ abstract class UDFRegistration {
* Register a deterministic Java UDF18 instance as user-defined function (UDF).
* @since 1.3.0
*/
- def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+ def register(
+ name: String,
+ f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _],
+ returnType: DataType): Unit = {
registerJavaUDF(name, ToScalaUDF(f), returnType, 18)
}
@@ -485,7 +1283,10 @@ abstract class UDFRegistration {
* Register a deterministic Java UDF19 instance as user-defined function (UDF).
* @since 1.3.0
*/
- def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+ def register(
+ name: String,
+ f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _],
+ returnType: DataType): Unit = {
registerJavaUDF(name, ToScalaUDF(f), returnType, 19)
}
@@ -493,7 +1294,10 @@ abstract class UDFRegistration {
* Register a deterministic Java UDF20 instance as user-defined function (UDF).
* @since 1.3.0
*/
- def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+ def register(
+ name: String,
+ f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _],
+ returnType: DataType): Unit = {
registerJavaUDF(name, ToScalaUDF(f), returnType, 20)
}
@@ -501,7 +1305,10 @@ abstract class UDFRegistration {
* Register a deterministic Java UDF21 instance as user-defined function (UDF).
* @since 1.3.0
*/
- def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+ def register(
+ name: String,
+ f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _],
+ returnType: DataType): Unit = {
registerJavaUDF(name, ToScalaUDF(f), returnType, 21)
}
@@ -509,7 +1316,10 @@ abstract class UDFRegistration {
* Register a deterministic Java UDF22 instance as user-defined function (UDF).
* @since 1.3.0
*/
- def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
+ def register(
+ name: String,
+ f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _],
+ returnType: DataType): Unit = {
registerJavaUDF(name, ToScalaUDF(f), returnType, 22)
}
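Not part of the diff: the Java-facing overloads above take an explicit return `DataType` because the erased `UDFn` interfaces carry no type tags. A hedged sketch using one of the lower-arity siblings of the overloads shown, reusing the `spark` session from the previous sketch:

  import org.apache.spark.sql.api.java.UDF2
  import org.apache.spark.sql.types.LongType

  val multiply = new UDF2[java.lang.Long, java.lang.Long, java.lang.Long] {
    override def call(x: java.lang.Long, y: java.lang.Long): java.lang.Long =
      java.lang.Long.valueOf(x.longValue() * y.longValue())
  }

  // Dispatches to the UDF2 overload (arity 2); the return type must be supplied explicitly.
  spark.udf.register("multiply", multiply, LongType)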
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/interface.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalog/interface.scala
similarity index 95%
rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/interface.scala
rename to sql/api/src/main/scala/org/apache/spark/sql/catalog/interface.scala
index 33e9007ac7e2b..3a3ba9d261326 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/interface.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalog/interface.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalog
import javax.annotation.Nullable
+import org.apache.spark.annotation.Stable
import org.apache.spark.sql.catalyst.DefinedByConstructorParams
// Note: all classes here are expected to be wrapped in Datasets and so must extend
@@ -31,16 +32,13 @@ import org.apache.spark.sql.catalyst.DefinedByConstructorParams
* name of the catalog
* @param description
* description of the catalog
- * @since 3.5.0
+ * @since 3.4.0
*/
class CatalogMetadata(val name: String, @Nullable val description: String)
extends DefinedByConstructorParams {
- override def toString: String = {
- "Catalog[" +
- s"name='$name', " +
- Option(description).map { d => s"description='$d'] " }.getOrElse("]")
- }
+ override def toString: String =
+ s"Catalog[name='$name', ${Option(description).map(d => s"description='$d'").getOrElse("")}]"
}
/**
@@ -54,8 +52,9 @@ class CatalogMetadata(val name: String, @Nullable val description: String)
* description of the database.
* @param locationUri
* path (in the form of a uri) to data files.
- * @since 3.5.0
+ * @since 2.0.0
*/
+@Stable
class Database(
val name: String,
@Nullable val catalog: String,
@@ -92,8 +91,9 @@ class Database(
* type of the table (e.g. view, table).
* @param isTemporary
* whether the table is a temporary table.
- * @since 3.5.0
+ * @since 2.0.0
*/
+@Stable
class Table(
val name: String,
@Nullable val catalog: String,
@@ -155,8 +155,9 @@ class Table(
* whether the column is a bucket column.
* @param isCluster
* whether the column is a clustering column.
- * @since 3.5.0
+ * @since 2.0.0
*/
+@Stable
class Column(
val name: String,
@Nullable val description: String,
@@ -205,8 +206,9 @@ class Column(
* the fully qualified class name of the function.
* @param isTemporary
* whether the function is a temporary function or not.
- * @since 3.5.0
+ * @since 2.0.0
*/
+@Stable
class Function(
val name: String,
@Nullable val catalog: String,
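Not part of the diff: these metadata classes are what the public catalog API hands back as Datasets, which is why they must extend DefinedByConstructorParams. A small sketch of how they typically surface:

  // Each row of these Datasets is one of the classes annotated @Stable above.
  val dbs = spark.catalog.listDatabases()      // Dataset[Database]
  dbs.show(truncate = false)

  val tables = spark.catalog.listTables()      // Dataset[Table]
  tables.filter(t => !t.isTemporary).show()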
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/DefinedByConstructorParams.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/DefinedByConstructorParams.scala
index fc6bc2095a821..efd9d457ef7d4 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/DefinedByConstructorParams.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/DefinedByConstructorParams.scala
@@ -18,8 +18,7 @@
package org.apache.spark.sql.catalyst
/**
- * A helper trait to create [[org.apache.spark.sql.catalyst.encoders.ExpressionEncoder]]s
- * for classes whose fields are entirely defined by constructor params but should not be
- * case classes.
+ * A helper trait to create [[org.apache.spark.sql.catalyst.encoders.ExpressionEncoder]]s for
+ * classes whose fields are entirely defined by constructor params but should not be case classes.
*/
trait DefinedByConstructorParams
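Not part of the diff: a minimal sketch of the marker trait in use, with a hypothetical class. The catalog classes above (`Database`, `Table`, `Column`, `Function`) follow the same pattern: plain, non-case classes whose fields all come from constructor parameters.

  import org.apache.spark.sql.catalyst.DefinedByConstructorParams

  // Hypothetical example: not a case class, but encodable because every field
  // is defined by a constructor parameter and the marker trait is mixed in.
  class SensorReading(val id: String, val value: Double) extends DefinedByConstructorParams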
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index f85e96da2be11..8d0103ca69635 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -36,10 +36,13 @@ import org.apache.spark.util.ArrayImplicits._
* Type-inference utilities for POJOs and Java collections.
*/
object JavaTypeInference {
+
/**
* Infers the corresponding SQL data type of a Java type.
- * @param beanType Java type
- * @return (SQL data type, nullable)
+ * @param beanType
+ * Java type
+ * @return
+ * (SQL data type, nullable)
*/
def inferDataType(beanType: Type): (DataType, Boolean) = {
val encoder = encoderFor(beanType)
@@ -60,8 +63,10 @@ object JavaTypeInference {
encoderFor(beanType, Set.empty).asInstanceOf[AgnosticEncoder[T]]
}
- private def encoderFor(t: Type, seenTypeSet: Set[Class[_]],
- typeVariables: Map[TypeVariable[_], Type] = Map.empty): AgnosticEncoder[_] = t match {
+ private def encoderFor(
+ t: Type,
+ seenTypeSet: Set[Class[_]],
+ typeVariables: Map[TypeVariable[_], Type] = Map.empty): AgnosticEncoder[_] = t match {
case c: Class[_] if c == java.lang.Boolean.TYPE => PrimitiveBooleanEncoder
case c: Class[_] if c == java.lang.Byte.TYPE => PrimitiveByteEncoder
@@ -94,14 +99,22 @@ object JavaTypeInference {
case c: Class[_] if c.isEnum => JavaEnumEncoder(ClassTag(c))
case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
- val udt = c.getAnnotation(classOf[SQLUserDefinedType]).udt()
- .getConstructor().newInstance().asInstanceOf[UserDefinedType[Any]]
+ val udt = c
+ .getAnnotation(classOf[SQLUserDefinedType])
+ .udt()
+ .getConstructor()
+ .newInstance()
+ .asInstanceOf[UserDefinedType[Any]]
val udtClass = udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()
UDTEncoder(udt, udtClass)
case c: Class[_] if UDTRegistration.exists(c.getName) =>
- val udt = UDTRegistration.getUDTFor(c.getName).get.getConstructor().
- newInstance().asInstanceOf[UserDefinedType[Any]]
+ val udt = UDTRegistration
+ .getUDTFor(c.getName)
+ .get
+ .getConstructor()
+ .newInstance()
+ .asInstanceOf[UserDefinedType[Any]]
UDTEncoder(udt, udt.getClass)
case c: Class[_] if c.isArray =>
@@ -160,7 +173,8 @@ object JavaTypeInference {
def getJavaBeanReadableProperties(beanClass: Class[_]): Array[PropertyDescriptor] = {
val beanInfo = Introspector.getBeanInfo(beanClass)
- beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
+ beanInfo.getPropertyDescriptors
+ .filterNot(_.getName == "class")
.filterNot(_.getName == "declaringClass")
.filter(_.getReadMethod != null)
}
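Not part of the diff: the bean-property introspection above is the machinery behind bean encoders such as `Encoders.bean`. A hedged sketch with a hypothetical top-level bean class:

  import org.apache.spark.sql.Encoders

  // Hypothetical Java-style bean: no-arg constructor plus getter/setter pairs.
  class Point {
    private var x: Int = 0
    private var y: Int = 0
    def getX(): Int = x
    def setX(v: Int): Unit = { x = v }
    def getY(): Int = y
    def setY(v: Int): Unit = { y = v }
  }

  val enc = Encoders.bean(classOf[Point])
  println(enc.schema)  // roughly: a StructType with integer fields x and y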
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index f204421b0add2..cd12cbd267cc4 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -55,9 +55,8 @@ object ScalaReflection extends ScalaReflection {
import scala.collection.Map
/**
- * Synchronize to prevent concurrent usage of `<:<` operator.
- * This operator is not thread safe in any current version of scala; i.e.
- * (2.11.12, 2.12.10, 2.13.0-M5).
+ * Synchronize to prevent concurrent usage of `<:<` operator. This operator is not thread safe
+ * in any current version of scala; i.e. (2.11.12, 2.12.10, 2.13.0-M5).
*
* See https://github.com/scala/bug/issues/10766
*/
@@ -91,11 +90,11 @@ object ScalaReflection extends ScalaReflection {
/**
* Workaround for [[https://github.com/scala/bug/issues/12190 Scala bug #12190]]
*
- * `ClassSymbol.selfType` can throw an exception in case of cyclic annotation reference
- * in Java classes. A retry of this operation will succeed as the class which defines the
- * cycle is now resolved. It can however expose further recursive annotation references, so
- * we keep retrying until we exhaust our retry threshold. Default threshold is set to 5
- * to allow for a few level of cyclic references.
+ * `ClassSymbol.selfType` can throw an exception in case of cyclic annotation reference in Java
+ * classes. A retry of this operation will succeed as the class which defines the cycle is now
+ * resolved. It can however expose further recursive annotation references, so we keep retrying
+ * until we exhaust our retry threshold. Default threshold is set to 5 to allow for a few level
+ * of cyclic references.
*/
@tailrec
private def selfType(clsSymbol: ClassSymbol, tries: Int = 5): Type = {
@@ -107,7 +106,7 @@ object ScalaReflection extends ScalaReflection {
// Retry on Symbols#CyclicReference if we haven't exhausted our retry limit
selfType(clsSymbol, tries - 1)
case Failure(e: RuntimeException)
- if e.getMessage.contains("illegal cyclic reference") && tries > 1 =>
+ if e.getMessage.contains("illegal cyclic reference") && tries > 1 =>
// UnPickler.unpickle wraps the original Symbols#CyclicReference exception into a runtime
// exception and does not set the cause, so we inspect the message. The previous case
// statement is useful for Java classes while this one is for Scala classes.
@@ -131,14 +130,14 @@ object ScalaReflection extends ScalaReflection {
}
/**
- * Returns the full class name for a type. The returned name is the canonical
- * Scala name, where each component is separated by a period. It is NOT the
- * Java-equivalent runtime name (no dollar signs).
+ * Returns the full class name for a type. The returned name is the canonical Scala name, where
+ * each component is separated by a period. It is NOT the Java-equivalent runtime name (no
+ * dollar signs).
*
- * In simple cases, both the Scala and Java names are the same, however when Scala
- * generates constructs that do not map to a Java equivalent, such as singleton objects
- * or nested classes in package objects, it uses the dollar sign ($) to create
- * synthetic classes, emulating behaviour in Java bytecode.
+ * In simple cases, both the Scala and Java names are the same, however when Scala generates
+ * constructs that do not map to a Java equivalent, such as singleton objects or nested classes
+ * in package objects, it uses the dollar sign ($) to create synthetic classes, emulating
+ * behaviour in Java bytecode.
*/
def getClassNameFromType(tpe: `Type`): String = {
erasure(tpe).dealias.typeSymbol.asClass.fullName
@@ -152,20 +151,25 @@ object ScalaReflection extends ScalaReflection {
case class Schema(dataType: DataType, nullable: Boolean)
- /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
+ /**
+ * Returns a catalyst DataType and its nullability for the given Scala Type using reflection.
+ */
def schemaFor[T: TypeTag]: Schema = schemaFor(localTypeOf[T])
- /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
+ /**
+ * Returns a catalyst DataType and its nullability for the given Scala Type using reflection.
+ */
def schemaFor(tpe: `Type`): Schema = {
val enc = encoderFor(tpe)
Schema(enc.dataType, enc.nullable)
}
/**
- * Finds an accessible constructor with compatible parameters. This is a more flexible search than
- * the exact matching algorithm in `Class.getConstructor`. The first assignment-compatible
+ * Finds an accessible constructor with compatible parameters. This is a more flexible search
+ * than the exact matching algorithm in `Class.getConstructor`. The first assignment-compatible
* matching constructor is returned if it exists. Otherwise, we check for additional compatible
- * constructors defined in the companion object as `apply` methods. Otherwise, it returns `None`.
+ * constructors defined in the companion object as `apply` methods. Otherwise, it returns
+ * `None`.
*/
def findConstructor[T](cls: Class[T], paramTypes: Seq[Class[_]]): Option[Seq[AnyRef] => T] = {
Option(ConstructorUtils.getMatchingAccessibleConstructor(cls, paramTypes: _*)) match {
@@ -174,24 +178,28 @@ object ScalaReflection extends ScalaReflection {
val companion = mirror.staticClass(cls.getName).companion
val moduleMirror = mirror.reflectModule(companion.asModule)
val applyMethods = companion.asTerm.typeSignature
- .member(universe.TermName("apply")).asTerm.alternatives
- applyMethods.find { method =>
- val params = method.typeSignature.paramLists.head
- // Check that the needed params are the same length and of matching types
- params.size == paramTypes.size &&
- params.zip(paramTypes).forall { case(ps, pc) =>
- ps.typeSignature.typeSymbol == mirror.classSymbol(pc)
+ .member(universe.TermName("apply"))
+ .asTerm
+ .alternatives
+ applyMethods
+ .find { method =>
+ val params = method.typeSignature.paramLists.head
+ // Check that the needed params are the same length and of matching types
+ params.size == paramTypes.size &&
+ params.zip(paramTypes).forall { case (ps, pc) =>
+ ps.typeSignature.typeSymbol == mirror.classSymbol(pc)
+ }
}
- }.map { applyMethodSymbol =>
- val expectedArgsCount = applyMethodSymbol.typeSignature.paramLists.head.size
- val instanceMirror = mirror.reflect(moduleMirror.instance)
- val method = instanceMirror.reflectMethod(applyMethodSymbol.asMethod)
- (_args: Seq[AnyRef]) => {
- // Drop the "outer" argument if it is provided
- val args = if (_args.size == expectedArgsCount) _args else _args.tail
- method.apply(args: _*).asInstanceOf[T]
+ .map { applyMethodSymbol =>
+ val expectedArgsCount = applyMethodSymbol.typeSignature.paramLists.head.size
+ val instanceMirror = mirror.reflect(moduleMirror.instance)
+ val method = instanceMirror.reflectMethod(applyMethodSymbol.asMethod)
+ (_args: Seq[AnyRef]) => {
+ // Drop the "outer" argument if it is provided
+ val args = if (_args.size == expectedArgsCount) _args else _args.tail
+ method.apply(args: _*).asInstanceOf[T]
+ }
}
- }
}
}
@@ -201,8 +209,10 @@ object ScalaReflection extends ScalaReflection {
def definedByConstructorParams(tpe: Type): Boolean = cleanUpReflectionObjects {
tpe.dealias match {
// `Option` is a `Product`, but we don't wanna treat `Option[Int]` as a struct type.
- case t if isSubtype(t, localTypeOf[Option[_]]) => definedByConstructorParams(t.typeArgs.head)
- case _ => isSubtype(tpe.dealias, localTypeOf[Product]) ||
+ case t if isSubtype(t, localTypeOf[Option[_]]) =>
+ definedByConstructorParams(t.typeArgs.head)
+ case _ =>
+ isSubtype(tpe.dealias, localTypeOf[Product]) ||
isSubtype(tpe.dealias, localTypeOf[DefinedByConstructorParams])
}
}
@@ -214,16 +224,15 @@ object ScalaReflection extends ScalaReflection {
/**
* Create an [[AgnosticEncoder]] from a [[TypeTag]].
*
- * If the given type is not supported, i.e. there is no encoder can be built for this type,
- * an [[SparkUnsupportedOperationException]] will be thrown with detailed error message to
- * explain the type path walked so far and which class we are not supporting.
- * There are 4 kinds of type path:
- * * the root type: `root class: "abc.xyz.MyClass"`
- * * the value type of [[Option]]: `option value class: "abc.xyz.MyClass"`
- * * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"`
- * * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")`
+ * If the given type is not supported, i.e. no encoder can be built for it, a
+ * [[SparkUnsupportedOperationException]] will be thrown with a detailed error message explaining
+ * the type path walked so far and which class is not supported. There are 4 kinds of type path:
+ *   - the root type: `root class: "abc.xyz.MyClass"`
+ *   - the value type of [[Option]]: `option value class: "abc.xyz.MyClass"`
+ *   - the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"`
+ *   - the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")`
*/
- def encoderFor[E : TypeTag]: AgnosticEncoder[E] = {
+ def encoderFor[E: TypeTag]: AgnosticEncoder[E] = {
encoderFor(typeTag[E].in(mirror).tpe).asInstanceOf[AgnosticEncoder[E]]
}
@@ -239,13 +248,12 @@ object ScalaReflection extends ScalaReflection {
/**
* Create an [[AgnosticEncoder]] for a [[Type]].
*/
- def encoderFor(
- tpe: `Type`,
- isRowEncoderSupported: Boolean = false): AgnosticEncoder[_] = cleanUpReflectionObjects {
- val clsName = getClassNameFromType(tpe)
- val walkedTypePath = WalkedTypePath().recordRoot(clsName)
- encoderFor(tpe, Set.empty, walkedTypePath, isRowEncoderSupported)
- }
+ def encoderFor(tpe: `Type`, isRowEncoderSupported: Boolean = false): AgnosticEncoder[_] =
+ cleanUpReflectionObjects {
+ val clsName = getClassNameFromType(tpe)
+ val walkedTypePath = WalkedTypePath().recordRoot(clsName)
+ encoderFor(tpe, Set.empty, walkedTypePath, isRowEncoderSupported)
+ }
private def encoderFor(
tpe: `Type`,
@@ -327,14 +335,22 @@ object ScalaReflection extends ScalaReflection {
// UDT encoders
case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) =>
- val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().
- getConstructor().newInstance().asInstanceOf[UserDefinedType[Any]]
+ val udt = getClassFromType(t)
+ .getAnnotation(classOf[SQLUserDefinedType])
+ .udt()
+ .getConstructor()
+ .newInstance()
+ .asInstanceOf[UserDefinedType[Any]]
val udtClass = udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()
UDTEncoder(udt, udtClass)
case t if UDTRegistration.exists(getClassNameFromType(t)) =>
- val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.getConstructor().
- newInstance().asInstanceOf[UserDefinedType[Any]]
+ val udt = UDTRegistration
+ .getUDTFor(getClassNameFromType(t))
+ .get
+ .getConstructor()
+ .newInstance()
+ .asInstanceOf[UserDefinedType[Any]]
UDTEncoder(udt, udt.getClass)
// Complex encoders
@@ -380,20 +396,17 @@ object ScalaReflection extends ScalaReflection {
if (seenTypeSet.contains(t)) {
throw ExecutionErrors.cannotHaveCircularReferencesInClassError(t.toString)
}
- val params = getConstructorParameters(t).map {
- case (fieldName, fieldType) =>
- if (SourceVersion.isKeyword(fieldName) ||
- !SourceVersion.isIdentifier(encodeFieldNameToIdentifier(fieldName))) {
- throw ExecutionErrors.cannotUseInvalidJavaIdentifierAsFieldNameError(
- fieldName,
- path)
- }
- val encoder = encoderFor(
- fieldType,
- seenTypeSet + t,
- path.recordField(getClassNameFromType(fieldType), fieldName),
- isRowEncoderSupported)
- EncoderField(fieldName, encoder, encoder.nullable, Metadata.empty)
+ val params = getConstructorParameters(t).map { case (fieldName, fieldType) =>
+ if (SourceVersion.isKeyword(fieldName) ||
+ !SourceVersion.isIdentifier(encodeFieldNameToIdentifier(fieldName))) {
+ throw ExecutionErrors.cannotUseInvalidJavaIdentifierAsFieldNameError(fieldName, path)
+ }
+ val encoder = encoderFor(
+ fieldType,
+ seenTypeSet + t,
+ path.recordField(getClassNameFromType(fieldType), fieldName),
+ isRowEncoderSupported)
+ EncoderField(fieldName, encoder, encoder.nullable, Metadata.empty)
}
val cls = getClassFromType(t)
ProductEncoder(ClassTag(cls), params, Option(OuterScopes.getOuterScope(cls)))
@@ -404,10 +417,11 @@ object ScalaReflection extends ScalaReflection {
}
/**
- * Support for generating catalyst schemas for scala objects. Note that unlike its companion
+ * Support for generating catalyst schemas for scala objects. Note that unlike its companion
* object, this trait able to work in both the runtime and the compile time (macro) universe.
*/
trait ScalaReflection extends Logging {
+
/** The universe we work in (runtime or macro) */
val universe: scala.reflect.api.Universe
@@ -421,7 +435,8 @@ trait ScalaReflection extends Logging {
* clean up the Scala reflection garbage automatically. Otherwise, it will leak some objects to
* `scala.reflect.runtime.JavaUniverse.undoLog`.
*
- * @see https://github.com/scala/bug/issues/8302
+ * @see
+ * https://github.com/scala/bug/issues/8302
*/
def cleanUpReflectionObjects[T](func: => T): T = {
universe.asInstanceOf[scala.reflect.runtime.JavaUniverse].undoLog.undo(func)
@@ -430,12 +445,13 @@ trait ScalaReflection extends Logging {
/**
* Return the Scala Type for `T` in the current classloader mirror.
*
- * Use this method instead of the convenience method `universe.typeOf`, which
- * assumes that all types can be found in the classloader that loaded scala-reflect classes.
- * That's not necessarily the case when running using Eclipse launchers or even
- * Sbt console or test (without `fork := true`).
+ * Use this method instead of the convenience method `universe.typeOf`, which assumes that all
+ * types can be found in the classloader that loaded scala-reflect classes. That's not
+ * necessarily the case when running using Eclipse launchers or even Sbt console or test
+ * (without `fork := true`).
*
- * @see SPARK-5281
+ * @see
+ * SPARK-5281
*/
def localTypeOf[T: TypeTag]: `Type` = {
val tag = implicitly[TypeTag[T]]
@@ -474,8 +490,8 @@ trait ScalaReflection extends Logging {
}
/**
- * If our type is a Scala trait it may have a companion object that
- * only defines a constructor via `apply` method.
+ * If our type is a Scala trait it may have a companion object that only defines a constructor
+ * via `apply` method.
*/
private def getCompanionConstructor(tpe: Type): Symbol = {
def throwUnsupportedOperation = {
@@ -483,10 +499,11 @@ trait ScalaReflection extends Logging {
}
tpe.typeSymbol.asClass.companion match {
case NoSymbol => throwUnsupportedOperation
- case sym => sym.asTerm.typeSignature.member(universe.TermName("apply")) match {
- case NoSymbol => throwUnsupportedOperation
- case constructorSym => constructorSym
- }
+ case sym =>
+ sym.asTerm.typeSignature.member(universe.TermName("apply")) match {
+ case NoSymbol => throwUnsupportedOperation
+ case constructorSym => constructorSym
+ }
}
}
@@ -499,8 +516,9 @@ trait ScalaReflection extends Logging {
constructorSymbol.asMethod.paramLists
} else {
// Find the primary constructor, and use its parameter ordering.
- val primaryConstructorSymbol: Option[Symbol] = constructorSymbol.asTerm.alternatives.find(
- s => s.isMethod && s.asMethod.isPrimaryConstructor)
+ val primaryConstructorSymbol: Option[Symbol] =
+ constructorSymbol.asTerm.alternatives.find(s =>
+ s.isMethod && s.asMethod.isPrimaryConstructor)
if (primaryConstructorSymbol.isEmpty) {
throw ExecutionErrors.primaryConstructorNotFoundError(tpe.getClass)
} else {
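Not part of the diff: a minimal sketch of the (internal) reflection entry points touched above, assuming a simple case class defined in user code:

  import org.apache.spark.sql.catalyst.ScalaReflection

  case class Person(name: String, age: Option[Int])

  // Schema(dataType, nullable) derived via runtime reflection.
  val schema = ScalaReflection.schemaFor[Person]

  // AgnosticEncoder[Person]; unsupported types raise SparkUnsupportedOperationException,
  // reporting the walked type path described earlier.
  val enc = ScalaReflection.encoderFor[Person]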
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala
index cbf1f01344c92..a81c071295e2c 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala
@@ -18,9 +18,9 @@
package org.apache.spark.sql.catalyst
/**
- * This class records the paths the serializer and deserializer walk through to reach current path.
- * Note that this class adds new path in prior to recorded paths so it maintains
- * the paths as reverse order.
+ * This class records the paths the serializer and deserializer walk through to reach the
+ * current path. Note that this class prepends each new path to the recorded paths, so it
+ * maintains the paths in reverse order.
*/
case class WalkedTypePath(private val walkedPaths: Seq[String] = Nil) {
def recordRoot(className: String): WalkedTypePath =
@@ -33,7 +33,8 @@ case class WalkedTypePath(private val walkedPaths: Seq[String] = Nil) {
newInstance(s"""- array element class: "$elementClassName"""")
def recordMap(keyClassName: String, valueClassName: String): WalkedTypePath = {
- newInstance(s"""- map key class: "$keyClassName"""" +
+ newInstance(
+ s"""- map key class: "$keyClassName"""" +
s""", value class: "$valueClassName"""")
}
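A small illustrative sketch of how a walked path accumulates (the class names are made up):

import org.apache.spark.sql.catalyst.WalkedTypePath

object WalkedTypePathSketch {
  // Visit a root class, then one of its map fields; each call prepends to the recorded paths.
  val path: WalkedTypePath = WalkedTypePath()
    .recordRoot("com.example.Person")
    .recordMap("java.lang.String", "java.lang.Integer")
  // The map entry now appears before the root entry, i.e. in reverse visiting order.
}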
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/NonEmptyException.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/NonEmptyException.scala
index 9955f1b7bd301..913881f326c90 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/NonEmptyException.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/NonEmptyException.scala
@@ -20,20 +20,17 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.util.QuotingUtils.quoted
-
/**
- * Thrown by a catalog when an item already exists. The analyzer will rethrow the exception
- * as an [[org.apache.spark.sql.AnalysisException]] with the correct position information.
+ * Thrown by a catalog when a namespace is not empty. The analyzer will rethrow the exception as
+ * an [[org.apache.spark.sql.AnalysisException]] with the correct position information.
*/
case class NonEmptyNamespaceException(
namespace: Array[String],
details: String,
override val cause: Option[Throwable] = None)
- extends AnalysisException(
- errorClass = "_LEGACY_ERROR_TEMP_3103",
- messageParameters = Map(
- "namespace" -> quoted(namespace),
- "details" -> details)) {
+ extends AnalysisException(
+ errorClass = "_LEGACY_ERROR_TEMP_3103",
+ messageParameters = Map("namespace" -> quoted(namespace), "details" -> details)) {
def this(namespace: Array[String]) = this(namespace, "", None)
}
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/SqlApiAnalysis.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/SqlApiAnalysis.scala
index 9f5a5b8875b33..f218a12209d61 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/SqlApiAnalysis.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/SqlApiAnalysis.scala
@@ -17,9 +17,10 @@
package org.apache.spark.sql.catalyst.analysis
object SqlApiAnalysis {
+
/**
- * Resolver should return true if the first string refers to the same entity as the second string.
- * For example, by using case insensitive equality.
+ * Resolver should return true if the first string refers to the same entity as the second
+ * string. For example, by using case insensitive equality.
*/
type Resolver = (String, String) => Boolean
}
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/alreadyExistException.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/alreadyExistException.scala
index fae3711baf706..0c667dd8916ee 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/alreadyExistException.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/alreadyExistException.scala
@@ -25,21 +25,21 @@ import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.util.ArrayImplicits._
/**
- * Thrown by a catalog when an item already exists. The analyzer will rethrow the exception
- * as an [[org.apache.spark.sql.AnalysisException]] with the correct position information.
+ * Thrown by a catalog when an item already exists. The analyzer will rethrow the exception as an
+ * [[org.apache.spark.sql.AnalysisException]] with the correct position information.
*/
class DatabaseAlreadyExistsException(db: String)
- extends NamespaceAlreadyExistsException(Array(db))
+ extends NamespaceAlreadyExistsException(Array(db))
// any changes to this class should be backward compatible as it may be used by external connectors
-class NamespaceAlreadyExistsException private(
+class NamespaceAlreadyExistsException private (
message: String,
errorClass: Option[String],
messageParameters: Map[String, String])
- extends AnalysisException(
- message,
- errorClass = errorClass,
- messageParameters = messageParameters) {
+ extends AnalysisException(
+ message,
+ errorClass = errorClass,
+ messageParameters = messageParameters) {
def this(errorClass: String, messageParameters: Map[String, String]) = {
this(
@@ -49,24 +49,28 @@ class NamespaceAlreadyExistsException private(
}
def this(namespace: Array[String]) = {
- this(errorClass = "SCHEMA_ALREADY_EXISTS",
+ this(
+ errorClass = "SCHEMA_ALREADY_EXISTS",
Map("schemaName" -> quoteNameParts(namespace.toImmutableArraySeq)))
}
}
// any changes to this class should be backward compatible as it may be used by external connectors
-class TableAlreadyExistsException private(
+class TableAlreadyExistsException private (
message: String,
cause: Option[Throwable],
errorClass: Option[String],
messageParameters: Map[String, String])
- extends AnalysisException(
- message,
- cause = cause,
- errorClass = errorClass,
- messageParameters = messageParameters) {
+ extends AnalysisException(
+ message,
+ cause = cause,
+ errorClass = errorClass,
+ messageParameters = messageParameters) {
- def this(errorClass: String, messageParameters: Map[String, String], cause: Option[Throwable]) = {
+ def this(
+ errorClass: String,
+ messageParameters: Map[String, String],
+ cause: Option[Throwable]) = {
this(
SparkThrowableHelper.getMessage(errorClass, messageParameters),
cause,
@@ -75,47 +79,53 @@ class TableAlreadyExistsException private(
}
def this(db: String, table: String) = {
- this(errorClass = "TABLE_OR_VIEW_ALREADY_EXISTS",
- messageParameters = Map("relationName" ->
- (quoteIdentifier(db) + "." + quoteIdentifier(table))),
+ this(
+ errorClass = "TABLE_OR_VIEW_ALREADY_EXISTS",
+ messageParameters = Map(
+ "relationName" ->
+ (quoteIdentifier(db) + "." + quoteIdentifier(table))),
cause = None)
}
def this(table: String) = {
- this(errorClass = "TABLE_OR_VIEW_ALREADY_EXISTS",
- messageParameters = Map("relationName" ->
- quoteNameParts(AttributeNameParser.parseAttributeName(table))),
+ this(
+ errorClass = "TABLE_OR_VIEW_ALREADY_EXISTS",
+ messageParameters = Map(
+ "relationName" ->
+ quoteNameParts(AttributeNameParser.parseAttributeName(table))),
cause = None)
}
def this(table: Seq[String]) = {
- this(errorClass = "TABLE_OR_VIEW_ALREADY_EXISTS",
+ this(
+ errorClass = "TABLE_OR_VIEW_ALREADY_EXISTS",
messageParameters = Map("relationName" -> quoteNameParts(table)),
cause = None)
}
def this(tableIdent: Identifier) = {
- this(errorClass = "TABLE_OR_VIEW_ALREADY_EXISTS",
+ this(
+ errorClass = "TABLE_OR_VIEW_ALREADY_EXISTS",
messageParameters = Map("relationName" -> quoted(tableIdent)),
cause = None)
}
}
-class TempTableAlreadyExistsException private(
+class TempTableAlreadyExistsException private (
message: String,
cause: Option[Throwable],
errorClass: Option[String],
messageParameters: Map[String, String])
- extends AnalysisException(
- message,
- cause = cause,
- errorClass = errorClass,
- messageParameters = messageParameters) {
+ extends AnalysisException(
+ message,
+ cause = cause,
+ errorClass = errorClass,
+ messageParameters = messageParameters) {
def this(
- errorClass: String,
- messageParameters: Map[String, String],
- cause: Option[Throwable] = None) = {
+ errorClass: String,
+ messageParameters: Map[String, String],
+ cause: Option[Throwable] = None) = {
this(
SparkThrowableHelper.getMessage(errorClass, messageParameters),
cause,
@@ -126,27 +136,31 @@ class TempTableAlreadyExistsException private(
def this(table: String) = {
this(
errorClass = "TEMP_TABLE_OR_VIEW_ALREADY_EXISTS",
- messageParameters = Map("relationName"
- -> quoteNameParts(AttributeNameParser.parseAttributeName(table))))
+ messageParameters = Map(
+ "relationName"
+ -> quoteNameParts(AttributeNameParser.parseAttributeName(table))))
}
}
// any changes to this class should be backward compatible as it may be used by external connectors
class ViewAlreadyExistsException(errorClass: String, messageParameters: Map[String, String])
- extends AnalysisException(errorClass, messageParameters) {
+ extends AnalysisException(errorClass, messageParameters) {
def this(ident: Identifier) =
- this(errorClass = "VIEW_ALREADY_EXISTS",
+ this(
+ errorClass = "VIEW_ALREADY_EXISTS",
messageParameters = Map("relationName" -> quoted(ident)))
}
// any changes to this class should be backward compatible as it may be used by external connectors
class FunctionAlreadyExistsException(errorClass: String, messageParameters: Map[String, String])
- extends AnalysisException(errorClass, messageParameters) {
+ extends AnalysisException(errorClass, messageParameters) {
def this(function: Seq[String]) = {
- this (errorClass = "ROUTINE_ALREADY_EXISTS",
- Map("routineName" -> quoteNameParts(function),
+ this(
+ errorClass = "ROUTINE_ALREADY_EXISTS",
+ Map(
+ "routineName" -> quoteNameParts(function),
"newRoutineType" -> "routine",
"existingRoutineType" -> "routine"))
}
@@ -157,16 +171,16 @@ class FunctionAlreadyExistsException(errorClass: String, messageParameters: Map[
}
// any changes to this class should be backward compatible as it may be used by external connectors
-class IndexAlreadyExistsException private(
+class IndexAlreadyExistsException private (
message: String,
cause: Option[Throwable],
errorClass: Option[String],
messageParameters: Map[String, String])
- extends AnalysisException(
- message,
- cause = cause,
- errorClass = errorClass,
- messageParameters = messageParameters) {
+ extends AnalysisException(
+ message,
+ cause = cause,
+ errorClass = errorClass,
+ messageParameters = messageParameters) {
def this(
errorClass: String,
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/noSuchItemsExceptions.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/noSuchItemsExceptions.scala
index 8977d0be24d77..dbc7622c761e7 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/noSuchItemsExceptions.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/analysis/noSuchItemsExceptions.scala
@@ -24,21 +24,24 @@ import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.util.ArrayImplicits._
/**
- * Thrown by a catalog when an item cannot be found. The analyzer will rethrow the exception
- * as an [[org.apache.spark.sql.AnalysisException]] with the correct position information.
+ * Thrown by a catalog when an item cannot be found. The analyzer will rethrow the exception as an
+ * [[org.apache.spark.sql.AnalysisException]] with the correct position information.
*/
-class NoSuchDatabaseException private[analysis](
+class NoSuchDatabaseException private[analysis] (
message: String,
cause: Option[Throwable],
errorClass: Option[String],
messageParameters: Map[String, String])
- extends AnalysisException(
- message,
- cause = cause,
- errorClass = errorClass,
- messageParameters = messageParameters) {
+ extends AnalysisException(
+ message,
+ cause = cause,
+ errorClass = errorClass,
+ messageParameters = messageParameters) {
- def this(errorClass: String, messageParameters: Map[String, String], cause: Option[Throwable]) = {
+ def this(
+ errorClass: String,
+ messageParameters: Map[String, String],
+ cause: Option[Throwable]) = {
this(
SparkThrowableHelper.getMessage(errorClass, messageParameters),
cause = cause,
@@ -55,16 +58,16 @@ class NoSuchDatabaseException private[analysis](
}
// any changes to this class should be backward compatible as it may be used by external connectors
-class NoSuchNamespaceException private(
+class NoSuchNamespaceException private (
message: String,
cause: Option[Throwable],
errorClass: Option[String],
messageParameters: Map[String, String])
- extends NoSuchDatabaseException(
- message,
- cause = cause,
- errorClass = errorClass,
- messageParameters = messageParameters) {
+ extends NoSuchDatabaseException(
+ message,
+ cause = cause,
+ errorClass = errorClass,
+ messageParameters = messageParameters) {
def this(errorClass: String, messageParameters: Map[String, String]) = {
this(
@@ -75,29 +78,32 @@ class NoSuchNamespaceException private(
}
def this(namespace: Seq[String]) = {
- this(errorClass = "SCHEMA_NOT_FOUND",
- Map("schemaName" -> quoteNameParts(namespace)))
+ this(errorClass = "SCHEMA_NOT_FOUND", Map("schemaName" -> quoteNameParts(namespace)))
}
def this(namespace: Array[String]) = {
- this(errorClass = "SCHEMA_NOT_FOUND",
+ this(
+ errorClass = "SCHEMA_NOT_FOUND",
Map("schemaName" -> quoteNameParts(namespace.toImmutableArraySeq)))
}
}
// any changes to this class should be backward compatible as it may be used by external connectors
-class NoSuchTableException private(
+class NoSuchTableException private (
message: String,
cause: Option[Throwable],
errorClass: Option[String],
messageParameters: Map[String, String])
- extends AnalysisException(
- message,
- cause = cause,
- errorClass = errorClass,
- messageParameters = messageParameters) {
+ extends AnalysisException(
+ message,
+ cause = cause,
+ errorClass = errorClass,
+ messageParameters = messageParameters) {
- def this(errorClass: String, messageParameters: Map[String, String], cause: Option[Throwable]) = {
+ def this(
+ errorClass: String,
+ messageParameters: Map[String, String],
+ cause: Option[Throwable]) = {
this(
SparkThrowableHelper.getMessage(errorClass, messageParameters),
cause = cause,
@@ -108,12 +114,13 @@ class NoSuchTableException private(
def this(db: String, table: String) = {
this(
errorClass = "TABLE_OR_VIEW_NOT_FOUND",
- messageParameters = Map("relationName" ->
- (quoteIdentifier(db) + "." + quoteIdentifier(table))),
+ messageParameters = Map(
+ "relationName" ->
+ (quoteIdentifier(db) + "." + quoteIdentifier(table))),
cause = None)
}
- def this(name : Seq[String]) = {
+ def this(name: Seq[String]) = {
this(
errorClass = "TABLE_OR_VIEW_NOT_FOUND",
messageParameters = Map("relationName" -> quoteNameParts(name)),
@@ -130,28 +137,28 @@ class NoSuchTableException private(
// any changes to this class should be backward compatible as it may be used by external connectors
class NoSuchViewException(errorClass: String, messageParameters: Map[String, String])
- extends AnalysisException(errorClass, messageParameters) {
+ extends AnalysisException(errorClass, messageParameters) {
def this(ident: Identifier) =
- this(errorClass = "VIEW_NOT_FOUND",
- messageParameters = Map("relationName" -> quoted(ident)))
+ this(errorClass = "VIEW_NOT_FOUND", messageParameters = Map("relationName" -> quoted(ident)))
}
class NoSuchPermanentFunctionException(db: String, func: String)
- extends AnalysisException(errorClass = "ROUTINE_NOT_FOUND",
- Map("routineName" -> (quoteIdentifier(db) + "." + quoteIdentifier(func))))
+ extends AnalysisException(
+ errorClass = "ROUTINE_NOT_FOUND",
+ Map("routineName" -> (quoteIdentifier(db) + "." + quoteIdentifier(func))))
// any changes to this class should be backward compatible as it may be used by external connectors
-class NoSuchFunctionException private(
+class NoSuchFunctionException private (
message: String,
cause: Option[Throwable],
errorClass: Option[String],
messageParameters: Map[String, String])
- extends AnalysisException(
- message,
- cause = cause,
- errorClass = errorClass,
- messageParameters = messageParameters) {
+ extends AnalysisException(
+ message,
+ cause = cause,
+ errorClass = errorClass,
+ messageParameters = messageParameters) {
def this(errorClass: String, messageParameters: Map[String, String]) = {
this(
@@ -162,7 +169,8 @@ class NoSuchFunctionException private(
}
def this(db: String, func: String) = {
- this(errorClass = "ROUTINE_NOT_FOUND",
+ this(
+ errorClass = "ROUTINE_NOT_FOUND",
Map("routineName" -> (quoteIdentifier(db) + "." + quoteIdentifier(func))))
}
@@ -172,19 +180,19 @@ class NoSuchFunctionException private(
}
class NoSuchTempFunctionException(func: String)
- extends AnalysisException(errorClass = "ROUTINE_NOT_FOUND", Map("routineName" -> s"`$func`"))
+ extends AnalysisException(errorClass = "ROUTINE_NOT_FOUND", Map("routineName" -> s"`$func`"))
// any changes to this class should be backward compatible as it may be used by external connectors
-class NoSuchIndexException private(
+class NoSuchIndexException private (
message: String,
cause: Option[Throwable],
errorClass: Option[String],
messageParameters: Map[String, String])
- extends AnalysisException(
- message,
- cause = cause,
- errorClass = errorClass,
- messageParameters = messageParameters) {
+ extends AnalysisException(
+ message,
+ cause = cause,
+ errorClass = errorClass,
+ messageParameters = messageParameters) {
def this(
errorClass: String,
@@ -201,3 +209,14 @@ class NoSuchIndexException private(
this("INDEX_NOT_FOUND", Map("indexName" -> indexName, "tableName" -> tableName), cause)
}
}
+
+class CannotReplaceMissingTableException(
+ tableIdentifier: Identifier,
+ cause: Option[Throwable] = None)
+ extends AnalysisException(
+ errorClass = "TABLE_OR_VIEW_NOT_FOUND",
+ messageParameters = Map(
+ "relationName"
+ -> quoteNameParts(
+ (tableIdentifier.namespace :+ tableIdentifier.name).toImmutableArraySeq)),
+ cause = cause)
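A hypothetical sketch of where the newly added exception could be thrown; `requireTarget` and the surrounding logic are invented for the example, only `TableCatalog.tableExists` is a real DSv2 call:

import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException
import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}

object ReplaceTableSketch {
  // A REPLACE TABLE style operation requires the target to already exist.
  def requireTarget(catalog: TableCatalog, ident: Identifier): Unit = {
    if (!catalog.tableExists(ident)) {
      throw new CannotReplaceMissingTableException(ident)
    }
  }
}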
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
index 639b23f714149..10f734b3f84ed 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
@@ -19,24 +19,23 @@ package org.apache.spark.sql.catalyst.encoders
import java.{sql => jsql}
import java.math.{BigDecimal => JBigDecimal, BigInteger => JBigInt}
import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period}
-import java.util.concurrent.ConcurrentHashMap
import scala.reflect.{classTag, ClassTag}
import org.apache.spark.sql.{Encoder, Row}
+import org.apache.spark.sql.errors.ExecutionErrors
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, VariantVal}
import org.apache.spark.util.SparkClassUtils
/**
- * A non implementation specific encoder. This encoder containers all the information needed
- * to generate an implementation specific encoder (e.g. InternalRow <=> Custom Object).
+ * A non-implementation-specific encoder. This encoder contains all the information needed to
+ * generate an implementation-specific encoder (e.g. InternalRow <=> Custom Object).
*
* The input of the serialization does not need to match the external type of the encoder. This is
- * called lenient serialization. An example of this is lenient date serialization, in this case both
- * [[java.sql.Date]] and [[java.time.LocalDate]] are allowed. Deserialization is never lenient; it
- * will always produce instance of the external type.
- *
+ * called lenient serialization. An example of this is lenient date serialization; in this case
+ * both [[java.sql.Date]] and [[java.time.LocalDate]] are allowed. Deserialization is never
+ * lenient; it will always produce an instance of the external type.
*/
trait AgnosticEncoder[T] extends Encoder[T] {
def isPrimitive: Boolean
@@ -47,16 +46,29 @@ trait AgnosticEncoder[T] extends Encoder[T] {
def isStruct: Boolean = false
}
+/**
+ * Extract an [[AgnosticEncoder]] from an [[Encoder]].
+ */
+trait ToAgnosticEncoder[T] {
+ def encoder: AgnosticEncoder[T]
+}
+
object AgnosticEncoders {
+ def agnosticEncoderFor[T: Encoder]: AgnosticEncoder[T] = implicitly[Encoder[T]] match {
+ case a: AgnosticEncoder[T] => a
+ case e: ToAgnosticEncoder[T @unchecked] => e.encoder
+ case other => throw ExecutionErrors.invalidAgnosticEncoderError(other)
+ }
+
case class OptionEncoder[E](elementEncoder: AgnosticEncoder[E])
- extends AgnosticEncoder[Option[E]] {
+ extends AgnosticEncoder[Option[E]] {
override def isPrimitive: Boolean = false
override def dataType: DataType = elementEncoder.dataType
override val clsTag: ClassTag[Option[E]] = ClassTag(classOf[Option[E]])
}
case class ArrayEncoder[E](element: AgnosticEncoder[E], containsNull: Boolean)
- extends AgnosticEncoder[Array[E]] {
+ extends AgnosticEncoder[Array[E]] {
override def isPrimitive: Boolean = false
override def dataType: DataType = ArrayType(element.dataType, containsNull)
override val clsTag: ClassTag[Array[E]] = element.clsTag.wrap
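A minimal sketch of the `agnosticEncoderFor` helper added above: an `AgnosticEncoder` found in implicit scope is returned as-is, while other `Encoder` implementations are expected to expose one via `ToAgnosticEncoder`. `StringEncoder` is assumed to be the usual `AgnosticEncoder[String]` leaf from this object:

import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.agnosticEncoderFor

object AgnosticEncoderForSketch {
  // The first match arm applies: the implicit Encoder is already an AgnosticEncoder.
  implicit val stringEncoder: Encoder[String] = AgnosticEncoders.StringEncoder
  val resolved: AgnosticEncoder[String] = agnosticEncoderFor[String]
}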
@@ -73,7 +85,7 @@ object AgnosticEncoders {
element: AgnosticEncoder[E],
containsNull: Boolean,
override val lenientSerialization: Boolean)
- extends AgnosticEncoder[C] {
+ extends AgnosticEncoder[C] {
override def isPrimitive: Boolean = false
override val dataType: DataType = ArrayType(element.dataType, containsNull)
}
@@ -83,12 +95,10 @@ object AgnosticEncoders {
keyEncoder: AgnosticEncoder[K],
valueEncoder: AgnosticEncoder[V],
valueContainsNull: Boolean)
- extends AgnosticEncoder[C] {
+ extends AgnosticEncoder[C] {
override def isPrimitive: Boolean = false
- override val dataType: DataType = MapType(
- keyEncoder.dataType,
- valueEncoder.dataType,
- valueContainsNull)
+ override val dataType: DataType =
+ MapType(keyEncoder.dataType, valueEncoder.dataType, valueContainsNull)
}
case class EncoderField(
@@ -114,17 +124,28 @@ object AgnosticEncoders {
case class ProductEncoder[K](
override val clsTag: ClassTag[K],
override val fields: Seq[EncoderField],
- outerPointerGetter: Option[() => AnyRef]) extends StructEncoder[K]
+ outerPointerGetter: Option[() => AnyRef])
+ extends StructEncoder[K]
object ProductEncoder {
- val cachedCls = new ConcurrentHashMap[Int, Class[_]]
- private[sql] def tuple(encoders: Seq[AgnosticEncoder[_]]): AgnosticEncoder[_] = {
- val fields = encoders.zipWithIndex.map {
- case (e, id) => EncoderField(s"_${id + 1}", e, e.nullable, Metadata.empty)
+ private val MAX_TUPLE_ELEMENTS = 22
+
+ private val tupleClassTags = Array.tabulate[ClassTag[Any]](MAX_TUPLE_ELEMENTS + 1) {
+ case 0 => null
+ case i => ClassTag(SparkClassUtils.classForName(s"scala.Tuple$i"))
+ }
+
+ private[sql] def tuple(
+ encoders: Seq[AgnosticEncoder[_]],
+ elementsCanBeNull: Boolean = false): AgnosticEncoder[_] = {
+ val numElements = encoders.size
+ if (numElements < 1 || numElements > MAX_TUPLE_ELEMENTS) {
+ throw ExecutionErrors.elementsOfTupleExceedLimitError()
}
- val cls = cachedCls.computeIfAbsent(encoders.size,
- _ => SparkClassUtils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}"))
- ProductEncoder[Any](ClassTag(cls), fields, None)
+ val fields = encoders.zipWithIndex.map { case (e, id) =>
+ EncoderField(s"_${id + 1}", e, e.nullable || elementsCanBeNull, Metadata.empty)
+ }
+ ProductEncoder[Any](tupleClassTags(numElements), fields, None)
}
private[sql] def isTuple(tag: ClassTag[_]): Boolean = {
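A sketch of the reworked `tuple` factory: the tuple class now comes from the precomputed `tupleClassTags` table rather than a cached classloader lookup, and an arity outside 1..22 is rejected up front. The method is `private[sql]`, so the example assumes a caller living under the `org.apache.spark.sql` package:

// Hypothetical subpackage; being under org.apache.spark.sql is what grants access to tuple.
package org.apache.spark.sql.sketch

import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._

object TupleEncoderSketch {
  // Targets scala.Tuple2 via the precomputed tupleClassTags lookup; with
  // elementsCanBeNull = true every field would additionally be marked nullable.
  val pairEncoder = ProductEncoder.tuple(Seq(BoxedLongEncoder, StringEncoder))
}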
@@ -141,19 +162,19 @@ object AgnosticEncoders {
object UnboundRowEncoder extends BaseRowEncoder {
override val schema: StructType = new StructType()
override val fields: Seq[EncoderField] = Seq.empty
-}
+ }
case class JavaBeanEncoder[K](
override val clsTag: ClassTag[K],
override val fields: Seq[EncoderField])
- extends StructEncoder[K]
+ extends StructEncoder[K]
// This will only work for encoding from/to Sparks' InternalRow format.
// It is here for compatibility.
case class UDTEncoder[E >: Null](
udt: UserDefinedType[E],
udtClass: Class[_ <: UserDefinedType[_]])
- extends AgnosticEncoder[E] {
+ extends AgnosticEncoder[E] {
override def isPrimitive: Boolean = false
override def dataType: DataType = udt
override def clsTag: ClassTag[E] = ClassTag(udt.userClass)
@@ -164,21 +185,19 @@ object AgnosticEncoders {
override def isPrimitive: Boolean = false
override def dataType: DataType = StringType
}
- case class ScalaEnumEncoder[T, E](
- parent: Class[T],
- override val clsTag: ClassTag[E])
- extends EnumEncoder[E]
+ case class ScalaEnumEncoder[T, E](parent: Class[T], override val clsTag: ClassTag[E])
+ extends EnumEncoder[E]
case class JavaEnumEncoder[E](override val clsTag: ClassTag[E]) extends EnumEncoder[E]
- protected abstract class LeafEncoder[E : ClassTag](override val dataType: DataType)
- extends AgnosticEncoder[E] {
+ protected abstract class LeafEncoder[E: ClassTag](override val dataType: DataType)
+ extends AgnosticEncoder[E] {
override val clsTag: ClassTag[E] = classTag[E]
override val isPrimitive: Boolean = clsTag.runtimeClass.isPrimitive
}
// Primitive encoders
- abstract class PrimitiveLeafEncoder[E : ClassTag](dataType: DataType)
- extends LeafEncoder[E](dataType)
+ abstract class PrimitiveLeafEncoder[E: ClassTag](dataType: DataType)
+ extends LeafEncoder[E](dataType)
case object PrimitiveBooleanEncoder extends PrimitiveLeafEncoder[Boolean](BooleanType)
case object PrimitiveByteEncoder extends PrimitiveLeafEncoder[Byte](ByteType)
case object PrimitiveShortEncoder extends PrimitiveLeafEncoder[Short](ShortType)
@@ -188,24 +207,24 @@ object AgnosticEncoders {
case object PrimitiveDoubleEncoder extends PrimitiveLeafEncoder[Double](DoubleType)
// Primitive wrapper encoders.
- abstract class BoxedLeafEncoder[E : ClassTag, P](
+ abstract class BoxedLeafEncoder[E: ClassTag, P](
dataType: DataType,
val primitive: PrimitiveLeafEncoder[P])
- extends LeafEncoder[E](dataType)
+ extends LeafEncoder[E](dataType)
case object BoxedBooleanEncoder
- extends BoxedLeafEncoder[java.lang.Boolean, Boolean](BooleanType, PrimitiveBooleanEncoder)
+ extends BoxedLeafEncoder[java.lang.Boolean, Boolean](BooleanType, PrimitiveBooleanEncoder)
case object BoxedByteEncoder
- extends BoxedLeafEncoder[java.lang.Byte, Byte](ByteType, PrimitiveByteEncoder)
+ extends BoxedLeafEncoder[java.lang.Byte, Byte](ByteType, PrimitiveByteEncoder)
case object BoxedShortEncoder
- extends BoxedLeafEncoder[java.lang.Short, Short](ShortType, PrimitiveShortEncoder)
+ extends BoxedLeafEncoder[java.lang.Short, Short](ShortType, PrimitiveShortEncoder)
case object BoxedIntEncoder
- extends BoxedLeafEncoder[java.lang.Integer, Int](IntegerType, PrimitiveIntEncoder)
+ extends BoxedLeafEncoder[java.lang.Integer, Int](IntegerType, PrimitiveIntEncoder)
case object BoxedLongEncoder
- extends BoxedLeafEncoder[java.lang.Long, Long](LongType, PrimitiveLongEncoder)
+ extends BoxedLeafEncoder[java.lang.Long, Long](LongType, PrimitiveLongEncoder)
case object BoxedFloatEncoder
- extends BoxedLeafEncoder[java.lang.Float, Float](FloatType, PrimitiveFloatEncoder)
+ extends BoxedLeafEncoder[java.lang.Float, Float](FloatType, PrimitiveFloatEncoder)
case object BoxedDoubleEncoder
- extends BoxedLeafEncoder[java.lang.Double, Double](DoubleType, PrimitiveDoubleEncoder)
+ extends BoxedLeafEncoder[java.lang.Double, Double](DoubleType, PrimitiveDoubleEncoder)
// Nullable leaf encoders
case object NullEncoder extends LeafEncoder[java.lang.Void](NullType)
@@ -218,19 +237,19 @@ object AgnosticEncoders {
case object YearMonthIntervalEncoder extends LeafEncoder[Period](YearMonthIntervalType())
case object VariantEncoder extends LeafEncoder[VariantVal](VariantType)
case class DateEncoder(override val lenientSerialization: Boolean)
- extends LeafEncoder[jsql.Date](DateType)
+ extends LeafEncoder[jsql.Date](DateType)
case class LocalDateEncoder(override val lenientSerialization: Boolean)
- extends LeafEncoder[LocalDate](DateType)
+ extends LeafEncoder[LocalDate](DateType)
case class TimestampEncoder(override val lenientSerialization: Boolean)
- extends LeafEncoder[jsql.Timestamp](TimestampType)
+ extends LeafEncoder[jsql.Timestamp](TimestampType)
case class InstantEncoder(override val lenientSerialization: Boolean)
- extends LeafEncoder[Instant](TimestampType)
+ extends LeafEncoder[Instant](TimestampType)
case object LocalDateTimeEncoder extends LeafEncoder[LocalDateTime](TimestampNTZType)
case class SparkDecimalEncoder(dt: DecimalType) extends LeafEncoder[Decimal](dt)
case class ScalaDecimalEncoder(dt: DecimalType) extends LeafEncoder[BigDecimal](dt)
case class JavaDecimalEncoder(dt: DecimalType, override val lenientSerialization: Boolean)
- extends LeafEncoder[JBigDecimal](dt)
+ extends LeafEncoder[JBigDecimal](dt)
val STRICT_DATE_ENCODER: DateEncoder = DateEncoder(lenientSerialization = false)
val STRICT_LOCAL_DATE_ENCODER: LocalDateEncoder = LocalDateEncoder(lenientSerialization = false)
@@ -257,7 +276,8 @@ object AgnosticEncoders {
case class TransformingEncoder[I, O](
clsTag: ClassTag[I],
transformed: AgnosticEncoder[O],
- codecProvider: () => Codec[_ >: I, O]) extends AgnosticEncoder[I] {
+ codecProvider: () => Codec[_ >: I, O])
+ extends AgnosticEncoder[I] {
override def isPrimitive: Boolean = transformed.isPrimitive
override def dataType: DataType = transformed.dataType
override def schema: StructType = transformed.schema
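A hedged sketch pairing `TransformingEncoder` with the `JavaSerializationCodec` defined in the codecs file further below: values of a made-up `Legacy` class are stored as their java-serialized bytes through the binary leaf encoder.

import scala.reflect.classTag

import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, JavaSerializationCodec}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, TransformingEncoder}

object TransformingEncoderSketch {
  case class Legacy(payload: String)

  // Legacy values are stored as java-serialized bytes through the binary leaf encoder;
  // the codec provider is only invoked when a concrete (de)serializer is built.
  val legacyEncoder: AgnosticEncoder[Legacy] =
    TransformingEncoder(classTag[Legacy], BinaryEncoder, JavaSerializationCodec)
}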
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
index 8587688956950..909556492847f 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
@@ -48,10 +48,9 @@ object OuterScopes {
/**
* Adds a new outer scope to this context that can be used when instantiating an `inner class`
- * during deserialization. Inner classes are created when a case class is defined in the
- * Spark REPL and registering the outer scope that this class was defined in allows us to create
- * new instances on the spark executors. In normal use, users should not need to call this
- * function.
+ * during deserialization. Inner classes are created when a case class is defined in the Spark
+ * REPL; registering the outer scope that this class was defined in allows us to create new
+ * instances on the Spark executors. In normal use, users should not need to call this function.
*
* Warning: this function operates on the assumption that there is only ever one instance of any
* given wrapper class.
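A sketch of the REPL pattern described above, assuming the usual `OuterScopes.addOuterScope(outer: AnyRef)` signature (not visible in this hunk):

import org.apache.spark.sql.catalyst.encoders.OuterScopes

class ReplLikeWrapper {
  // Point is an inner class; instantiating it during deserialization needs an outer pointer.
  case class Point(x: Int, y: Int)

  // Register this wrapper instance so encoders can re-create Point values on executors.
  OuterScopes.addOuterScope(this)
}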
@@ -65,7 +64,7 @@ object OuterScopes {
}
/**
- * Returns a function which can get the outer scope for the given inner class. By using function
+ * Returns a function which can get the outer scope for the given inner class. By using function
* as return type, we can delay the process of getting outer pointer to execution time, which is
* useful for inner class defined in REPL.
*/
@@ -134,8 +133,8 @@ object OuterScopes {
}
case _ => null
}
- } else {
- () => outer
+ } else { () =>
+ outer
}
}
@@ -162,7 +161,7 @@ object OuterScopes {
* dead entries after GC (using a [[ReferenceQueue]]).
*/
private[catalyst] class HashableWeakReference(v: AnyRef, queue: ReferenceQueue[AnyRef])
- extends WeakReference[AnyRef](v, queue) {
+ extends WeakReference[AnyRef](v, queue) {
def this(v: AnyRef) = this(v, null)
private[this] val hash = v.hashCode()
override def hashCode(): Int = hash
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index c507e952630f6..8b6da805a6e87 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -28,8 +28,8 @@ import org.apache.spark.sql.types._
import org.apache.spark.util.ArrayImplicits._
/**
- * A factory for constructing encoders that convert external row to/from the Spark SQL
- * internal binary representation.
+ * A factory for constructing encoders that convert external row to/from the Spark SQL internal
+ * binary representation.
*
* The following is a mapping between Spark SQL types and its allowed external types:
* {{{
@@ -68,67 +68,65 @@ object RowEncoder extends DataTypeErrorsBase {
encoderForDataType(schema, lenient).asInstanceOf[AgnosticEncoder[Row]]
}
- private[sql] def encoderForDataType(
- dataType: DataType,
- lenient: Boolean): AgnosticEncoder[_] = dataType match {
- case NullType => NullEncoder
- case BooleanType => BoxedBooleanEncoder
- case ByteType => BoxedByteEncoder
- case ShortType => BoxedShortEncoder
- case IntegerType => BoxedIntEncoder
- case LongType => BoxedLongEncoder
- case FloatType => BoxedFloatEncoder
- case DoubleType => BoxedDoubleEncoder
- case dt: DecimalType => JavaDecimalEncoder(dt, lenientSerialization = true)
- case BinaryType => BinaryEncoder
- case _: StringType => StringEncoder
- case TimestampType if SqlApiConf.get.datetimeJava8ApiEnabled => InstantEncoder(lenient)
- case TimestampType => TimestampEncoder(lenient)
- case TimestampNTZType => LocalDateTimeEncoder
- case DateType if SqlApiConf.get.datetimeJava8ApiEnabled => LocalDateEncoder(lenient)
- case DateType => DateEncoder(lenient)
- case CalendarIntervalType => CalendarIntervalEncoder
- case _: DayTimeIntervalType => DayTimeIntervalEncoder
- case _: YearMonthIntervalType => YearMonthIntervalEncoder
- case _: VariantType => VariantEncoder
- case p: PythonUserDefinedType =>
- // TODO check if this works.
- encoderForDataType(p.sqlType, lenient)
- case udt: UserDefinedType[_] =>
- val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType])
- val udtClass: Class[_] = if (annotation != null) {
- annotation.udt()
- } else {
- UDTRegistration.getUDTFor(udt.userClass.getName).getOrElse {
- throw ExecutionErrors.userDefinedTypeNotAnnotatedAndRegisteredError(udt)
+ private[sql] def encoderForDataType(dataType: DataType, lenient: Boolean): AgnosticEncoder[_] =
+ dataType match {
+ case NullType => NullEncoder
+ case BooleanType => BoxedBooleanEncoder
+ case ByteType => BoxedByteEncoder
+ case ShortType => BoxedShortEncoder
+ case IntegerType => BoxedIntEncoder
+ case LongType => BoxedLongEncoder
+ case FloatType => BoxedFloatEncoder
+ case DoubleType => BoxedDoubleEncoder
+ case dt: DecimalType => JavaDecimalEncoder(dt, lenientSerialization = true)
+ case BinaryType => BinaryEncoder
+ case _: StringType => StringEncoder
+ case TimestampType if SqlApiConf.get.datetimeJava8ApiEnabled => InstantEncoder(lenient)
+ case TimestampType => TimestampEncoder(lenient)
+ case TimestampNTZType => LocalDateTimeEncoder
+ case DateType if SqlApiConf.get.datetimeJava8ApiEnabled => LocalDateEncoder(lenient)
+ case DateType => DateEncoder(lenient)
+ case CalendarIntervalType => CalendarIntervalEncoder
+ case _: DayTimeIntervalType => DayTimeIntervalEncoder
+ case _: YearMonthIntervalType => YearMonthIntervalEncoder
+ case _: VariantType => VariantEncoder
+ case p: PythonUserDefinedType =>
+ // TODO check if this works.
+ encoderForDataType(p.sqlType, lenient)
+ case udt: UserDefinedType[_] =>
+ val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType])
+ val udtClass: Class[_] = if (annotation != null) {
+ annotation.udt()
+ } else {
+ UDTRegistration.getUDTFor(udt.userClass.getName).getOrElse {
+ throw ExecutionErrors.userDefinedTypeNotAnnotatedAndRegisteredError(udt)
+ }
}
- }
- UDTEncoder(udt, udtClass.asInstanceOf[Class[_ <: UserDefinedType[_]]])
- case ArrayType(elementType, containsNull) =>
- IterableEncoder(
- classTag[mutable.ArraySeq[_]],
- encoderForDataType(elementType, lenient),
- containsNull,
- lenientSerialization = true)
- case MapType(keyType, valueType, valueContainsNull) =>
- MapEncoder(
- classTag[scala.collection.Map[_, _]],
- encoderForDataType(keyType, lenient),
- encoderForDataType(valueType, lenient),
- valueContainsNull)
- case StructType(fields) =>
- AgnosticRowEncoder(fields.map { field =>
- EncoderField(
- field.name,
- encoderForDataType(field.dataType, lenient),
- field.nullable,
- field.metadata)
- }.toImmutableArraySeq)
+ UDTEncoder(udt, udtClass.asInstanceOf[Class[_ <: UserDefinedType[_]]])
+ case ArrayType(elementType, containsNull) =>
+ IterableEncoder(
+ classTag[mutable.ArraySeq[_]],
+ encoderForDataType(elementType, lenient),
+ containsNull,
+ lenientSerialization = true)
+ case MapType(keyType, valueType, valueContainsNull) =>
+ MapEncoder(
+ classTag[scala.collection.Map[_, _]],
+ encoderForDataType(keyType, lenient),
+ encoderForDataType(valueType, lenient),
+ valueContainsNull)
+ case StructType(fields) =>
+ AgnosticRowEncoder(fields.map { field =>
+ EncoderField(
+ field.name,
+ encoderForDataType(field.dataType, lenient),
+ field.nullable,
+ field.metadata)
+ }.toImmutableArraySeq)
- case _ =>
- throw new AnalysisException(
- errorClass = "UNSUPPORTED_DATA_TYPE_FOR_ENCODER",
- messageParameters = Map("dataType" -> toSQLType(dataType))
- )
- }
+ case _ =>
+ throw new AnalysisException(
+ errorClass = "UNSUPPORTED_DATA_TYPE_FOR_ENCODER",
+ messageParameters = Map("dataType" -> toSQLType(dataType)))
+ }
}
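A short sketch of the factory, assuming the public `encoderFor(schema: StructType)` entry point whose body appears at the top of this hunk:

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
import org.apache.spark.sql.types._

object RowEncoderSketch {
  val schema: StructType = StructType(Seq(
    StructField("id", LongType, nullable = false),
    StructField("tags", ArrayType(StringType))))

  // LongType maps to BoxedLongEncoder and the array to an IterableEncoder, per the match above.
  val encoder: AgnosticEncoder[Row] = RowEncoder.encoderFor(schema)
}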
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/codecs.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/codecs.scala
index 46862ebbccdfd..0f21972552339 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/codecs.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/codecs.scala
@@ -16,15 +16,20 @@
*/
package org.apache.spark.sql.catalyst.encoders
-import org.apache.spark.util.SparkSerDeUtils
+import java.lang.invoke.{MethodHandle, MethodHandles, MethodType}
+
+import org.apache.spark.sql.errors.ExecutionErrors
+import org.apache.spark.util.{SparkClassUtils, SparkSerDeUtils}
/**
* Codec for doing conversions between two representations.
*
- * @tparam I input type (typically the external representation of the data.
- * @tparam O output type (typically the internal representation of the data.
+ * @tparam I
+ * input type (typically the external representation of the data).
+ * @tparam O
+ * output type (typically the internal representation of the data).
*/
-trait Codec[I, O] {
+trait Codec[I, O] extends Serializable {
def encode(in: I): O
def decode(out: O): I
}
@@ -40,3 +45,29 @@ class JavaSerializationCodec[I] extends Codec[I, Array[Byte]] {
object JavaSerializationCodec extends (() => Codec[Any, Array[Byte]]) {
override def apply(): Codec[Any, Array[Byte]] = new JavaSerializationCodec[Any]
}
+
+/**
+ * A codec that uses Kryo to (de)serialize arbitrary objects to and from a byte array.
+ *
+ * Please note that this is currently only supported for Classic Spark applications. The reason
+ * for this is that Connect applications can have a significantly different classpath than the
+ * driver or executor. This makes having the same Kryo configuration on both the client and
+ * server (driver & executors) very tricky. As a workaround, a user can define their own Codec
+ * which internalizes the Kryo configuration.
+ */
+object KryoSerializationCodec extends (() => Codec[Any, Array[Byte]]) {
+ private lazy val kryoCodecConstructor: MethodHandle = {
+ val cls = SparkClassUtils.classForName(
+ "org.apache.spark.sql.catalyst.encoders.KryoSerializationCodecImpl")
+ MethodHandles.lookup().findConstructor(cls, MethodType.methodType(classOf[Unit]))
+ }
+
+ override def apply(): Codec[Any, Array[Byte]] = {
+ try {
+ kryoCodecConstructor.invoke().asInstanceOf[Codec[Any, Array[Byte]]]
+ } catch {
+ case _: ClassNotFoundException =>
+ throw ExecutionErrors.cannotUseKryoSerialization()
+ }
+ }
+}
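A minimal custom `Codec` illustrating the encode/decode contract that both the Java and Kryo codecs above satisfy; the UTF-8 codec itself is invented for the example and is not part of Spark:

import java.nio.charset.StandardCharsets

import org.apache.spark.sql.catalyst.encoders.Codec

class Utf8StringCodec extends Codec[String, Array[Byte]] {
  // encode goes external -> internal, decode goes back; both must round-trip losslessly.
  override def encode(in: String): Array[Byte] = in.getBytes(StandardCharsets.UTF_8)
  override def decode(out: Array[Byte]): String = new String(out, StandardCharsets.UTF_8)
}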
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/expressions/OrderUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/expressions/OrderUtils.scala
index 385e0f00695a3..76442accbd35b 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/expressions/OrderUtils.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/expressions/OrderUtils.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.types.{ArrayType, AtomicType, DataType, NullType, StructType, UserDefinedType, VariantType}
object OrderUtils {
+
/**
* Returns true iff the data type can be ordered (i.e. can be sorted).
*/
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index 7f21ab25ad4e5..6977d9f3185ac 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -21,11 +21,12 @@ import org.apache.spark.sql.types.StructType
import org.apache.spark.util.ArrayImplicits._
/**
- * A row implementation that uses an array of objects as the underlying storage. Note that, while
+ * A row implementation that uses an array of objects as the underlying storage. Note that, while
* the array is not copied, and thus could technically be mutated after creation, this is not
* allowed.
*/
class GenericRow(protected[sql] val values: Array[Any]) extends Row {
+
/** No-arg constructor for serialization. */
protected def this() = this(null)
@@ -41,7 +42,7 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row {
}
class GenericRowWithSchema(values: Array[Any], override val schema: StructType)
- extends GenericRow(values) {
+ extends GenericRow(values) {
/** No-arg constructor for serialization. */
protected def this() = this(null, null)
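A tiny usage sketch of the row classes above; the schema and values are illustrative:

import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types._

object GenericRowSketch {
  val schema: StructType = StructType(Seq(
    StructField("id", LongType),
    StructField("name", StringType)))

  // The backing array is not copied, so it must not be mutated afterwards.
  val row = new GenericRowWithSchema(Array(1L, "Ada"), schema)
  val name: String = row.getAs[String]("name")
}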
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala
index 38ecd29266db7..46fb4a3397c59 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala
@@ -120,7 +120,7 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
override def visitDayTimeIntervalDataType(ctx: DayTimeIntervalDataTypeContext): DataType = {
val startStr = ctx.from.getText.toLowerCase(Locale.ROOT)
val start = DayTimeIntervalType.stringToField(startStr)
- if (ctx.to != null ) {
+ if (ctx.to != null) {
val endStr = ctx.to.getText.toLowerCase(Locale.ROOT)
val end = DayTimeIntervalType.stringToField(endStr)
if (end <= start) {
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserInterface.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserInterface.scala
index ab665f360b0a6..1a1a9b01de3b1 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserInterface.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserInterface.scala
@@ -22,9 +22,10 @@ import org.apache.spark.sql.types.{DataType, StructType}
* Interface for [[DataType]] parsing functionality.
*/
trait DataTypeParserInterface {
+
/**
- * Parse a string to a [[StructType]]. The passed SQL string should be a comma separated list
- * of field definitions which will preserve the correct Hive metadata.
+ * Parse a string to a [[StructType]]. The passed SQL string should be a comma separated list of
+ * field definitions which will preserve the correct Hive metadata.
*/
@throws[ParseException]("Text cannot be parsed to a schema")
def parseTableSchema(sqlText: String): StructType
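A sketch of the parsing contract written against the trait itself, since the concrete parser class is outside this hunk; the `parser` parameter stands in for whichever implementation Spark wires in:

import org.apache.spark.sql.catalyst.parser.DataTypeParserInterface
import org.apache.spark.sql.types.StructType

object DataTypeParsingSketch {
  // A comma separated field list parses into a StructType, preserving Hive-compatible metadata.
  def tableSchemaOf(parser: DataTypeParserInterface): StructType =
    parser.parseTableSchema("id BIGINT, name STRING, tags ARRAY<STRING>")
}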
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/LegacyTypeStringParser.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/LegacyTypeStringParser.scala
index 8ac5939bca944..7d1986e727f79 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/LegacyTypeStringParser.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/LegacyTypeStringParser.scala
@@ -23,35 +23,36 @@ import org.apache.spark.sql.errors.DataTypeErrors
import org.apache.spark.sql.types._
/**
- * Parser that turns case class strings into datatypes. This is only here to maintain compatibility
- * with Parquet files written by Spark 1.1 and below.
+ * Parser that turns case class strings into datatypes. This is only here to maintain
+ * compatibility with Parquet files written by Spark 1.1 and below.
*/
object LegacyTypeStringParser extends RegexParsers {
protected lazy val primitiveType: Parser[DataType] =
- ( "StringType" ^^^ StringType
- | "FloatType" ^^^ FloatType
- | "IntegerType" ^^^ IntegerType
- | "ByteType" ^^^ ByteType
- | "ShortType" ^^^ ShortType
- | "DoubleType" ^^^ DoubleType
- | "LongType" ^^^ LongType
- | "BinaryType" ^^^ BinaryType
- | "BooleanType" ^^^ BooleanType
- | "DateType" ^^^ DateType
- | "DecimalType()" ^^^ DecimalType.USER_DEFAULT
- | fixedDecimalType
- | "TimestampType" ^^^ TimestampType
- )
+ (
+ "StringType" ^^^ StringType
+ | "FloatType" ^^^ FloatType
+ | "IntegerType" ^^^ IntegerType
+ | "ByteType" ^^^ ByteType
+ | "ShortType" ^^^ ShortType
+ | "DoubleType" ^^^ DoubleType
+ | "LongType" ^^^ LongType
+ | "BinaryType" ^^^ BinaryType
+ | "BooleanType" ^^^ BooleanType
+ | "DateType" ^^^ DateType
+ | "DecimalType()" ^^^ DecimalType.USER_DEFAULT
+ | fixedDecimalType
+ | "TimestampType" ^^^ TimestampType
+ )
protected lazy val fixedDecimalType: Parser[DataType] =
- ("DecimalType(" ~> "[0-9]+".r) ~ ("," ~> "[0-9]+".r <~ ")") ^^ {
- case precision ~ scale => DecimalType(precision.toInt, scale.toInt)
+ ("DecimalType(" ~> "[0-9]+".r) ~ ("," ~> "[0-9]+".r <~ ")") ^^ { case precision ~ scale =>
+ DecimalType(precision.toInt, scale.toInt)
}
protected lazy val arrayType: Parser[DataType] =
- "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ {
- case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull)
+ "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ { case tpe ~ _ ~ containsNull =>
+ ArrayType(tpe, containsNull)
}
protected lazy val mapType: Parser[DataType] =
@@ -66,21 +67,23 @@ object LegacyTypeStringParser extends RegexParsers {
}
protected lazy val boolVal: Parser[Boolean] =
- ( "true" ^^^ true
- | "false" ^^^ false
- )
+ (
+ "true" ^^^ true
+ | "false" ^^^ false
+ )
protected lazy val structType: Parser[DataType] =
- "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ {
- case fields => StructType(fields)
+ "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ { case fields =>
+ StructType(fields)
}
protected lazy val dataType: Parser[DataType] =
- ( arrayType
- | mapType
- | structType
- | primitiveType
- )
+ (
+ arrayType
+ | mapType
+ | structType
+ | primitiveType
+ )
/**
* Parses a string representation of a DataType.
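A hedged sketch of the parser in action; the entry point name (`parseString`) is not visible in this hunk, and the legacy string only illustrates the Spark 1.1-era format matched by the combinators above:

import org.apache.spark.sql.catalyst.parser.LegacyTypeStringParser
import org.apache.spark.sql.types.DataType

object LegacyTypeStringSketch {
  // Entry point name assumed; the string mirrors the structType/structField grammar above.
  val parsed: DataType = LegacyTypeStringParser.parseString(
    "StructType(List(StructField(a,IntegerType,true), StructField(b,StringType,true)))")
}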
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/SparkParserErrorStrategy.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/SparkParserErrorStrategy.scala
index 99e63d783838f..461d79ec22cf0 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/SparkParserErrorStrategy.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/SparkParserErrorStrategy.scala
@@ -32,7 +32,7 @@ class SparkRecognitionException(
ctx: ParserRuleContext,
val errorClass: Option[String] = None,
val messageParameters: Map[String, String] = Map.empty)
- extends RecognitionException(message, recognizer, input, ctx) {
+ extends RecognitionException(message, recognizer, input, ctx) {
/** Construct from a given [[RecognitionException]], with additional error information. */
def this(
@@ -50,7 +50,7 @@ class SparkRecognitionException(
Some(errorClass),
messageParameters)
- /** Construct with pure errorClass and messageParameter information. */
+ /** Construct with pure errorClass and messageParameter information. */
def this(errorClass: String, messageParameters: Map[String, String]) =
this("", null, null, null, Some(errorClass), messageParameters)
}
@@ -59,12 +59,12 @@ class SparkRecognitionException(
* A [[SparkParserErrorStrategy]] extends the [[DefaultErrorStrategy]], that does special handling
* on errors.
*
- * The intention of this class is to provide more information of these errors encountered in
- * ANTLR parser to the downstream consumers, to be able to apply the [[SparkThrowable]] error
- * message framework to these exceptions.
+ * The intention of this class is to provide more information about the errors encountered in
+ * the ANTLR parser to downstream consumers, so that the [[SparkThrowable]] error message
+ * framework can be applied to these exceptions.
*/
class SparkParserErrorStrategy() extends DefaultErrorStrategy {
- private val userWordDict : Map[String, String] = Map("''" -> "end of input")
+ private val userWordDict: Map[String, String] = Map("''" -> "end of input")
/** Get the user-facing display of the error token. */
override def getTokenErrorDisplay(t: Token): String = {
@@ -76,9 +76,7 @@ class SparkParserErrorStrategy() extends DefaultErrorStrategy {
val exceptionWithErrorClass = new SparkRecognitionException(
e,
"PARSE_SYNTAX_ERROR",
- messageParameters = Map(
- "error" -> getTokenErrorDisplay(e.getOffendingToken),
- "hint" -> ""))
+ messageParameters = Map("error" -> getTokenErrorDisplay(e.getOffendingToken), "hint" -> ""))
recognizer.notifyErrorListeners(e.getOffendingToken, "", exceptionWithErrorClass)
}
@@ -116,18 +114,17 @@ class SparkParserErrorStrategy() extends DefaultErrorStrategy {
/**
* Inspired by [[org.antlr.v4.runtime.BailErrorStrategy]], which is used in two-stage parsing:
- * This error strategy allows the first stage of two-stage parsing to immediately terminate
- * if an error is encountered, and immediately fall back to the second stage. In addition to
- * avoiding wasted work by attempting to recover from errors here, the empty implementation
- * of sync improves the performance of the first stage.
+ * This error strategy allows the first stage of two-stage parsing to immediately terminate if an
+ * error is encountered, and immediately fall back to the second stage. In addition to avoiding
+ * wasted work by attempting to recover from errors here, the empty implementation of sync
+ * improves the performance of the first stage.
*/
class SparkParserBailErrorStrategy() extends SparkParserErrorStrategy {
/**
- * Instead of recovering from exception e, re-throw it wrapped
- * in a [[ParseCancellationException]] so it is not caught by the
- * rule function catches. Use [[Exception#getCause]] to get the
- * original [[RecognitionException]].
+ * Instead of recovering from exception e, re-throw it wrapped in a
+ * [[ParseCancellationException]] so it is not caught by the rule function catches. Use
+ * [[Exception#getCause]] to get the original [[RecognitionException]].
*/
override def recover(recognizer: Parser, e: RecognitionException): Unit = {
var context = recognizer.getContext
@@ -139,8 +136,8 @@ class SparkParserBailErrorStrategy() extends SparkParserErrorStrategy {
}
/**
- * Make sure we don't attempt to recover inline; if the parser
- * successfully recovers, it won't throw an exception.
+ * Make sure we don't attempt to recover inline; if the parser successfully recovers, it won't
+ * throw an exception.
*/
@throws[RecognitionException]
override def recoverInline(recognizer: Parser): Token = {
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala
index 0b9e6ea021be1..10da24567545b 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala
@@ -36,6 +36,7 @@ import org.apache.spark.sql.types.{DataType, StructType}
* Base SQL parsing infrastructure.
*/
abstract class AbstractParser extends DataTypeParserInterface with Logging {
+
/** Creates/Resolves DataType for a given SQL string. */
override def parseDataType(sqlText: String): DataType = parse(sqlText) { parser =>
astBuilder.visitSingleDataType(parser.singleDataType())
@@ -78,8 +79,7 @@ abstract class AbstractParser extends DataTypeParserInterface with Logging {
parser.setErrorHandler(new SparkParserBailErrorStrategy())
parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
toResult(parser)
- }
- catch {
+ } catch {
case e: ParseCancellationException =>
// if we fail, parse with LL mode w/ SparkParserErrorStrategy
tokenStream.seek(0) // rewind input stream
@@ -90,8 +90,7 @@ abstract class AbstractParser extends DataTypeParserInterface with Logging {
parser.getInterpreter.setPredictionMode(PredictionMode.LL)
toResult(parser)
}
- }
- catch {
+ } catch {
case e: ParseException if e.command.isDefined =>
throw e
case e: ParseException =>
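A condensed sketch of the two-stage strategy reformatted above: a fast SLL pass with the bail-out error strategy, falling back to a full LL re-parse only when the first pass is cancelled. The helper itself is invented; the parser and strategy classes are the ones from this package:

// Assumes the org.apache.spark.sql.catalyst.parser package, where the ANTLR-generated
// SqlBaseParser and the two error strategies above live.
import org.antlr.v4.runtime.CommonTokenStream
import org.antlr.v4.runtime.atn.PredictionMode
import org.antlr.v4.runtime.misc.ParseCancellationException

object TwoStageParseSketch {
  def twoStageParse[T](parser: SqlBaseParser, tokenStream: CommonTokenStream)(
      toResult: SqlBaseParser => T): T = {
    try {
      // First stage: fast SLL prediction, bail out on the first error.
      parser.setErrorHandler(new SparkParserBailErrorStrategy())
      parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
      toResult(parser)
    } catch {
      case _: ParseCancellationException =>
        // Second stage: rewind and re-parse with full LL prediction and rich error reporting.
        tokenStream.seek(0)
        parser.reset()
        parser.setErrorHandler(new SparkParserErrorStrategy())
        parser.getInterpreter.setPredictionMode(PredictionMode.LL)
        toResult(parser)
    }
  }
}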
@@ -187,7 +186,7 @@ case object ParseErrorListener extends BaseErrorListener {
* A [[ParseException]] is an [[SparkException]] that is thrown during the parse process. It
* contains fields and an extended error message that make reporting and diagnosing errors easier.
*/
-class ParseException private(
+class ParseException private (
val command: Option[String],
message: String,
val start: Origin,
@@ -195,17 +194,18 @@ class ParseException private(
errorClass: Option[String] = None,
messageParameters: Map[String, String] = Map.empty,
queryContext: Array[QueryContext] = ParseException.getQueryContext())
- extends AnalysisException(
- message,
- start.line,
- start.startPosition,
- None,
- errorClass,
- messageParameters,
- queryContext) {
+ extends AnalysisException(
+ message,
+ start.line,
+ start.startPosition,
+ None,
+ errorClass,
+ messageParameters,
+ queryContext) {
def this(errorClass: String, messageParameters: Map[String, String], ctx: ParserRuleContext) =
- this(Option(SparkParserUtils.command(ctx)),
+ this(
+ Option(SparkParserUtils.command(ctx)),
SparkThrowableHelper.getMessage(errorClass, messageParameters),
SparkParserUtils.position(ctx.getStart),
SparkParserUtils.position(ctx.getStop),
@@ -310,14 +310,16 @@ case object PostProcessor extends SqlBaseParserBaseListener {
throw QueryParsingErrors.invalidIdentifierError(ident, ctx)
}
- /** Throws error message when unquoted identifier contains characters outside a-z, A-Z, 0-9, _ */
+ /**
+ * Throws an error when an unquoted identifier contains characters outside a-z, A-Z, 0-9, _
+ */
override def exitUnquotedIdentifier(ctx: SqlBaseParser.UnquotedIdentifierContext): Unit = {
val ident = ctx.getText
if (ident.exists(c =>
- !(c >= 'a' && c <= 'z') &&
- !(c >= 'A' && c <= 'Z') &&
- !(c >= '0' && c <= '9') &&
- c != '_')) {
+ !(c >= 'a' && c <= 'z') &&
+ !(c >= 'A' && c <= 'Z') &&
+ !(c >= '0' && c <= '9') &&
+ c != '_')) {
throw QueryParsingErrors.invalidIdentifierError(ident, ctx)
}
}
@@ -353,9 +355,7 @@ case object PostProcessor extends SqlBaseParserBaseListener {
replaceTokenByIdentifier(ctx, 0)(identity)
}
- private def replaceTokenByIdentifier(
- ctx: ParserRuleContext,
- stripMargins: Int)(
+ private def replaceTokenByIdentifier(ctx: ParserRuleContext, stripMargins: Int)(
f: CommonToken => CommonToken = identity): Unit = {
val parent = ctx.getParent
parent.removeLastChild()
@@ -373,8 +373,8 @@ case object PostProcessor extends SqlBaseParserBaseListener {
/**
* The post-processor checks the unclosed bracketed comment.
*/
-case class UnclosedCommentProcessor(
- command: String, tokenStream: CommonTokenStream) extends SqlBaseParserBaseListener {
+case class UnclosedCommentProcessor(command: String, tokenStream: CommonTokenStream)
+ extends SqlBaseParserBaseListener {
override def exitSingleDataType(ctx: SqlBaseParser.SingleDataTypeContext): Unit = {
checkUnclosedComment(tokenStream, command)
@@ -384,7 +384,8 @@ case class UnclosedCommentProcessor(
checkUnclosedComment(tokenStream, command)
}
- override def exitSingleTableIdentifier(ctx: SqlBaseParser.SingleTableIdentifierContext): Unit = {
+ override def exitSingleTableIdentifier(
+ ctx: SqlBaseParser.SingleTableIdentifierContext): Unit = {
checkUnclosedComment(tokenStream, command)
}
@@ -422,7 +423,8 @@ case class UnclosedCommentProcessor(
// The last token is 'EOF' and the penultimate is unclosed bracketed comment
val failedToken = tokenStream.get(tokenStream.size() - 2)
assert(failedToken.getType() == SqlBaseParser.BRACKETED_COMMENT)
- val position = Origin(Option(failedToken.getLine), Option(failedToken.getCharPositionInLine))
+ val position =
+ Origin(Option(failedToken.getLine), Option(failedToken.getCharPositionInLine))
throw QueryParsingErrors.unclosedBracketedCommentError(
command = command,
start = Origin(Option(failedToken.getStartIndex)),
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeMode.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeMode.scala
index 4c7c87504ffc4..e870a83ec4ae6 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeMode.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeMode.scala
@@ -37,9 +37,10 @@ object TimeModes {
ProcessingTime
case "eventtime" =>
EventTime
- case _ => throw new SparkIllegalArgumentException(
- errorClass = "STATEFUL_PROCESSOR_UNKNOWN_TIME_MODE",
- messageParameters = Map("timeMode" -> timeMode))
+ case _ =>
+ throw new SparkIllegalArgumentException(
+ errorClass = "STATEFUL_PROCESSOR_UNKNOWN_TIME_MODE",
+ messageParameters = Map("timeMode" -> timeMode))
}
}
}
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala
index 0ccfc3cbc7bf2..8c54758ebc0bc 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala
@@ -28,27 +28,26 @@ import org.apache.spark.sql.streaming.OutputMode
private[sql] object InternalOutputModes {
/**
- * OutputMode in which only the new rows in the streaming DataFrame/Dataset will be
- * written to the sink. This output mode can be only be used in queries that do not
- * contain any aggregation.
+ * OutputMode in which only the new rows in the streaming DataFrame/Dataset will be written to
+ the sink. This output mode can only be used in queries that do not contain any
+ * aggregation.
*/
case object Append extends OutputMode
/**
- * OutputMode in which all the rows in the streaming DataFrame/Dataset will be written
- * to the sink every time these is some updates. This output mode can only be used in queries
- * that contain aggregations.
+ * OutputMode in which all the rows in the streaming DataFrame/Dataset will be written to the
+ sink every time there are some updates. This output mode can only be used in queries that
+ * contain aggregations.
*/
case object Complete extends OutputMode
/**
- * OutputMode in which only the rows in the streaming DataFrame/Dataset that were updated will be
- * written to the sink every time these is some updates. If the query doesn't contain
+ * OutputMode in which only the rows in the streaming DataFrame/Dataset that were updated will
+ be written to the sink every time there are some updates. If the query doesn't contain
* aggregations, it will be equivalent to `Append` mode.
*/
case object Update extends OutputMode
-
def apply(outputMode: String): OutputMode = {
outputMode.toLowerCase(Locale.ROOT) match {
case "append" =>
@@ -57,7 +56,8 @@ private[sql] object InternalOutputModes {
OutputMode.Complete
case "update" =>
OutputMode.Update
- case _ => throw new SparkIllegalArgumentException(
+ case _ =>
+ throw new SparkIllegalArgumentException(
errorClass = "STREAMING_OUTPUT_MODE.INVALID",
messageParameters = Map("outputMode" -> outputMode))
}
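
The apply method above resolves mode names case-insensitively by lowercasing with Locale.ROOT before matching. A minimal sketch of the same pattern, standalone and with a plain exception in place of SparkIllegalArgumentException:

    import java.util.Locale

    sealed trait Mode
    case object Append extends Mode
    case object Complete extends Mode
    case object Update extends Mode

    def parseOutputMode(outputMode: String): Mode =
      outputMode.toLowerCase(Locale.ROOT) match {
        case "append"   => Append
        case "complete" => Complete
        case "update"   => Update
        case other =>
          throw new IllegalArgumentException(s"STREAMING_OUTPUT_MODE.INVALID: $other")
      }

    parseOutputMode("Update") // Update, regardless of the caller's casing
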
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala
index 2b3f4674539e3..d5be65a2f36cf 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala
@@ -27,7 +27,8 @@ case class SQLQueryContext(
originStopIndex: Option[Int],
sqlText: Option[String],
originObjectType: Option[String],
- originObjectName: Option[String]) extends QueryContext {
+ originObjectName: Option[String])
+ extends QueryContext {
override val contextType = QueryContextType.SQL
override val objectType = originObjectType.getOrElse("")
@@ -37,9 +38,8 @@ case class SQLQueryContext(
/**
* The SQL query context of current node. For example:
- * == SQL of VIEW v1 (line 1, position 25) ==
- * SELECT '' AS five, i.f1, i.f1 - int('2') AS x FROM INT4_TBL i
- * ^^^^^^^^^^^^^^^
+ * ==SQL of VIEW v1 (line 1, position 25)==
+ * SELECT '' AS five, i.f1, i.f1 - int('2') AS x FROM INT4_TBL i ^^^^^^^^^^^^^^^
*/
override lazy val summary: String = {
// If the query context is missing or incorrect, simply return an empty string.
@@ -127,8 +127,8 @@ case class SQLQueryContext(
def isValid: Boolean = {
sqlText.isDefined && originStartIndex.isDefined && originStopIndex.isDefined &&
- originStartIndex.get >= 0 && originStopIndex.get < sqlText.get.length &&
- originStartIndex.get <= originStopIndex.get
+ originStartIndex.get >= 0 && originStopIndex.get < sqlText.get.length &&
+ originStartIndex.get <= originStopIndex.get
}
override def callSite: String = throw SparkUnsupportedOperationException()
@@ -136,7 +136,8 @@ case class SQLQueryContext(
case class DataFrameQueryContext(
stackTrace: Seq[StackTraceElement],
- pysparkErrorContext: Option[(String, String)]) extends QueryContext {
+ pysparkErrorContext: Option[(String, String)])
+ extends QueryContext {
override val contextType = QueryContextType.DataFrame
override def objectType: String = throw SparkUnsupportedOperationException()
@@ -146,19 +147,21 @@ case class DataFrameQueryContext(
override val fragment: String = {
pysparkErrorContext.map(_._1).getOrElse {
- stackTrace.headOption.map { firstElem =>
- val methodName = firstElem.getMethodName
- if (methodName.length > 1 && methodName(0) == '$') {
- methodName.substring(1)
- } else {
- methodName
+ stackTrace.headOption
+ .map { firstElem =>
+ val methodName = firstElem.getMethodName
+ if (methodName.length > 1 && methodName(0) == '$') {
+ methodName.substring(1)
+ } else {
+ methodName
+ }
}
- }.getOrElse("")
+ .getOrElse("")
}
}
- override val callSite: String = pysparkErrorContext.map(
- _._2).getOrElse(stackTrace.tail.mkString("\n"))
+ override val callSite: String =
+ pysparkErrorContext.map(_._2).getOrElse(stackTrace.tail.mkString("\n"))
override lazy val summary: String = {
val builder = new StringBuilder
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala
index 3e0ebdd627b63..33fa17433abbd 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala
@@ -23,9 +23,9 @@ import org.apache.spark.sql.internal.SqlApiConf
import org.apache.spark.util.ArrayImplicits._
/**
- * Contexts of TreeNodes, including location, SQL text, object type and object name.
- * The only supported object type is "VIEW" now. In the future, we may support SQL UDF or other
- * objects which contain SQL text.
+ * Contexts of TreeNodes, including location, SQL text, object type and object name. The only
+ * supported object type is "VIEW" now. In the future, we may support SQL UDF or other objects
+ * which contain SQL text.
*/
case class Origin(
line: Option[Int] = None,
@@ -41,8 +41,7 @@ case class Origin(
lazy val context: QueryContext = if (stackTrace.isDefined) {
DataFrameQueryContext(stackTrace.get.toImmutableArraySeq, pysparkErrorContext)
} else {
- SQLQueryContext(
- line, startPosition, startIndex, stopIndex, sqlText, objectType, objectName)
+ SQLQueryContext(line, startPosition, startIndex, stopIndex, sqlText, objectType, objectName)
}
def getQueryContext: Array[QueryContext] = {
@@ -61,7 +60,7 @@ trait WithOrigin {
}
/**
- * Provides a location for TreeNodes to ask about the context of their origin. For example, which
+ * Provides a location for TreeNodes to ask about the context of their origin. For example, which
* line of code is currently being parsed.
*/
object CurrentOrigin {
@@ -75,8 +74,7 @@ object CurrentOrigin {
def reset(): Unit = value.set(Origin())
def setPosition(line: Int, start: Int): Unit = {
- value.set(
- value.get.copy(line = Some(line), startPosition = Some(start)))
+ value.set(value.get.copy(line = Some(line), startPosition = Some(start)))
}
def withOrigin[A](o: Origin)(f: => A): A = {
@@ -84,25 +82,29 @@ object CurrentOrigin {
// this way withOrigin can be recursive
val previous = get
set(o)
- val ret = try f finally { set(previous) }
+ val ret =
+ try f
+ finally { set(previous) }
ret
}
/**
- * This helper function captures the Spark API and its call site in the user code from the current
- * stacktrace.
+ * This helper function captures the Spark API and its call site in the user code from the
+ * current stacktrace.
*
* As adding `withOrigin` explicitly to all Spark API definitions would be a huge change,
* `withOrigin` is used only at certain places where all API implementations surely pass through
- * and the current stacktrace is filtered to the point where first Spark API code is invoked from
- * the user code.
+ * and the current stacktrace is filtered to the point where the first Spark API code is invoked
+ * from the user code.
*
* As there might be multiple nested `withOrigin` calls (e.g. any Spark API implementations can
* invoke other APIs) only the first `withOrigin` is captured because that is closer to the user
* code.
*
- * @param f The function that can use the origin.
- * @return The result of `f`.
+ * @param f
+ * The function that can use the origin.
+ * @return
+ * The result of `f`.
*/
private[sql] def withOrigin[T](f: => T): T = {
if (CurrentOrigin.get.stackTrace.isDefined) {
@@ -114,9 +116,9 @@ object CurrentOrigin {
while (i < st.length && !sparkCode(st(i))) i += 1
// Stop at the end of the first Spark code traces
while (i < st.length && sparkCode(st(i))) i += 1
- val origin = Origin(stackTrace = Some(st.slice(
- from = i - 1,
- until = i + SqlApiConf.get.stackTracesInDataFrameContext)),
+ val origin = Origin(
+ stackTrace =
+ Some(st.slice(from = i - 1, until = i + SqlApiConf.get.stackTracesInDataFrameContext)),
pysparkErrorContext = PySparkCurrentOrigin.get())
CurrentOrigin.withOrigin(origin)(f)
}
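
CurrentOrigin.withOrigin above follows a save/set/restore discipline on a thread-local value, which is what makes nested calls compose. A simplified sketch of that pattern with a cut-down Origin (illustrative only, not the Spark classes):

    case class SimpleOrigin(line: Option[Int] = None, startPosition: Option[Int] = None)

    object SimpleCurrentOrigin {
      private val value = new ThreadLocal[SimpleOrigin] {
        override def initialValue(): SimpleOrigin = SimpleOrigin()
      }

      def get: SimpleOrigin = value.get()

      def withOrigin[A](o: SimpleOrigin)(f: => A): A = {
        val previous = get                // remember the caller's origin
        value.set(o)
        try f finally value.set(previous) // always restore, so nesting works
      }
    }

    SimpleCurrentOrigin.withOrigin(SimpleOrigin(line = Some(3), startPosition = Some(10))) {
      SimpleCurrentOrigin.get.line        // Some(3) while this node is being processed
    }
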
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/AttributeNameParser.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/AttributeNameParser.scala
index e47ab1978d0ed..533b09e82df13 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/AttributeNameParser.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/AttributeNameParser.scala
@@ -19,11 +19,11 @@ package org.apache.spark.sql.catalyst.util
import org.apache.spark.sql.errors.DataTypeErrors
trait AttributeNameParser {
+
/**
- * Used to split attribute name by dot with backticks rule.
- * Backticks must appear in pairs, and the quoted string must be a complete name part,
- * which means `ab..c`e.f is not allowed.
- * We can use backtick only inside quoted name parts.
+ * Used to split attribute name by dot with backticks rule. Backticks must appear in pairs, and
+ * the quoted string must be a complete name part, which means `ab..c`e.f is not allowed. We can
+ * use backtick only inside quoted name parts.
*/
def parseAttributeName(name: String): Seq[String] = {
def e = DataTypeErrors.attributeNameSyntaxError(name)
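
To make the backtick rule above concrete: a.`b.c`.d splits into Seq("a", "b.c", "d"), while `ab..c`e.f is rejected because the quoted string is not a complete name part. Below is a simplified, standalone splitter along those lines; it is not the Spark implementation, it skips part of the validation, and it assumes the usual SQL convention that "``" inside a quoted part encodes a literal backtick:

    def splitAttributeName(name: String): Seq[String] = {
      val parts = scala.collection.mutable.ArrayBuffer.empty[String]
      val current = new StringBuilder
      var inQuote = false
      var i = 0
      while (i < name.length) {
        val c = name(i)
        if (inQuote) {
          if (c == '`') {
            // "``" inside a quoted part stands for a literal backtick
            if (i + 1 < name.length && name(i + 1) == '`') { current.append('`'); i += 1 }
            else inQuote = false
          } else current.append(c)
        } else c match {
          case '`' if current.isEmpty => inQuote = true // a quote must open a whole name part
          case '`' => throw new IllegalArgumentException(s"Syntax error in attribute name: $name")
          case '.' => parts += current.toString; current.clear()
          case other => current.append(other)
        }
        i += 1
      }
      require(!inQuote, s"Unclosed backtick in attribute name: $name")
      parts += current.toString
      parts.toSeq
    }

    splitAttributeName("a.`b.c`.d") // Seq("a", "b.c", "d")
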
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala
index 640304efce4b4..9e90feeb782d6 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/CaseInsensitiveMap.scala
@@ -23,12 +23,12 @@ import java.util.Locale
* Builds a map in which keys are case insensitive. Input map can be accessed for cases where
* case-sensitive information is required. The primary constructor is marked private to avoid
* nested case-insensitive map creation, otherwise the keys in the original map will become
- * case-insensitive in this scenario.
- * Note: CaseInsensitiveMap is serializable. However, after transformation, e.g. `filterKeys()`,
- * it may become not serializable.
+ * case-insensitive in this scenario. Note: CaseInsensitiveMap is serializable. However, after
+ * transformation, e.g. `filterKeys()`, it may no longer be serializable.
*/
-class CaseInsensitiveMap[T] private (val originalMap: Map[String, T]) extends Map[String, T]
- with Serializable {
+class CaseInsensitiveMap[T] private (val originalMap: Map[String, T])
+ extends Map[String, T]
+ with Serializable {
val keyLowerCasedMap = originalMap.map(kv => kv.copy(_1 = kv._1.toLowerCase(Locale.ROOT)))
@@ -62,4 +62,3 @@ object CaseInsensitiveMap {
case _ => new CaseInsensitiveMap(params)
}
}
-
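
A small usage sketch of the class above; lookups ignore case while originalMap keeps the caller's casing (the option names here are made up):

    import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

    val opts = CaseInsensitiveMap(Map("Path" -> "/tmp/data", "inferSchema" -> "true"))
    opts.get("PATH")             // Some("/tmp/data"): the key comparison is case-insensitive
    opts.contains("inferschema") // true
    opts.originalMap.keySet      // Set("Path", "inferSchema"): original casing is preserved
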
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeJsonUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeJsonUtils.scala
index e75429c58cc7b..b8ab633b2047f 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeJsonUtils.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/DataTypeJsonUtils.scala
@@ -25,15 +25,16 @@ import org.json4s.jackson.{JValueDeserializer, JValueSerializer}
import org.apache.spark.sql.types.DataType
object DataTypeJsonUtils {
+
/**
* Jackson serializer for [[DataType]]. Internally this delegates to json4s based serialization.
*/
class DataTypeJsonSerializer extends JsonSerializer[DataType] {
private val delegate = new JValueSerializer
override def serialize(
- value: DataType,
- gen: JsonGenerator,
- provider: SerializerProvider): Unit = {
+ value: DataType,
+ gen: JsonGenerator,
+ provider: SerializerProvider): Unit = {
delegate.serialize(value.jsonValue, gen, provider)
}
}
@@ -46,8 +47,8 @@ object DataTypeJsonUtils {
private val delegate = new JValueDeserializer(classOf[Any])
override def deserialize(
- jsonParser: JsonParser,
- deserializationContext: DeserializationContext): DataType = {
+ jsonParser: JsonParser,
+ deserializationContext: DeserializationContext): DataType = {
val json = delegate.deserialize(jsonParser, deserializationContext)
DataType.parseDataType(json.asInstanceOf[JValue])
}
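
The serializer/deserializer pair above round-trips a DataType through its json4s-based JSON form; the same representation is what the public DataType.json and DataType.fromJson helpers produce and consume, so a quick way to see the format is:

    import org.apache.spark.sql.types._

    val schema = StructType(Seq(StructField("id", LongType), StructField("name", StringType)))
    val json = schema.json                 // e.g. {"type":"struct","fields":[...]}
    val restored = DataType.fromJson(json) // parses back to an equal StructType
    restored == schema                     // true
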
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala
index 34d19bb67b71a..5eada9a7be670 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala
@@ -43,7 +43,8 @@ class Iso8601DateFormatter(
locale: Locale,
legacyFormat: LegacyDateFormats.LegacyDateFormat,
isParsing: Boolean)
- extends DateFormatter with DateTimeFormatterHelper {
+ extends DateFormatter
+ with DateTimeFormatterHelper {
@transient
private lazy val formatter = getOrCreateFormatter(pattern, locale, isParsing)
@@ -62,8 +63,7 @@ class Iso8601DateFormatter(
override def format(localDate: LocalDate): String = {
try {
localDate.format(formatter)
- } catch checkFormattedDiff(toJavaDate(localDateToDays(localDate)),
- (d: Date) => format(d))
+ } catch checkFormattedDiff(toJavaDate(localDateToDays(localDate)), (d: Date) => format(d))
}
override def format(days: Int): String = {
@@ -83,19 +83,22 @@ class Iso8601DateFormatter(
}
/**
- * The formatter for dates which doesn't require users to specify a pattern. While formatting,
- * it uses the default pattern [[DateFormatter.defaultPattern]]. In parsing, it follows the CAST
+ * The formatter for dates which doesn't require users to specify a pattern. While formatting, it
+ * uses the default pattern [[DateFormatter.defaultPattern]]. In parsing, it follows the CAST
* logic in conversion of strings to Catalyst's DateType.
*
- * @param locale The locale overrides the system locale and is used in formatting.
- * @param legacyFormat Defines the formatter used for legacy dates.
- * @param isParsing Whether the formatter is used for parsing (`true`) or for formatting (`false`).
+ * @param locale
+ * The locale overrides the system locale and is used in formatting.
+ * @param legacyFormat
+ * Defines the formatter used for legacy dates.
+ * @param isParsing
+ * Whether the formatter is used for parsing (`true`) or for formatting (`false`).
*/
class DefaultDateFormatter(
locale: Locale,
legacyFormat: LegacyDateFormats.LegacyDateFormat,
isParsing: Boolean)
- extends Iso8601DateFormatter(DateFormatter.defaultPattern, locale, legacyFormat, isParsing) {
+ extends Iso8601DateFormatter(DateFormatter.defaultPattern, locale, legacyFormat, isParsing) {
override def parse(s: String): Int = {
try {
@@ -125,11 +128,13 @@ trait LegacyDateFormatter extends DateFormatter {
* JVM time zone intentionally for compatibility with Spark 2.4 and earlier versions.
*
* Note: Use of the default JVM time zone makes the formatter compatible with the legacy
- * `SparkDateTimeUtils` methods `toJavaDate` and `fromJavaDate` that are based on the default
- * JVM time zone too.
+ * `SparkDateTimeUtils` methods `toJavaDate` and `fromJavaDate` that are based on the default JVM
+ * time zone too.
*
- * @param pattern `java.text.SimpleDateFormat` compatible pattern.
- * @param locale The locale overrides the system locale and is used in parsing/formatting.
+ * @param pattern
+ * `java.text.SimpleDateFormat` compatible pattern.
+ * @param locale
+ * The locale overrides the system locale and is used in parsing/formatting.
*/
class LegacyFastDateFormatter(pattern: String, locale: Locale) extends LegacyDateFormatter {
@transient
@@ -145,14 +150,16 @@ class LegacyFastDateFormatter(pattern: String, locale: Locale) extends LegacyDat
* JVM time zone intentionally for compatibility with Spark 2.4 and earlier versions.
*
* Note: Use of the default JVM time zone makes the formatter compatible with the legacy
- * `SparkDateTimeUtils` methods `toJavaDate` and `fromJavaDate` that are based on the default
- * JVM time zone too.
+ * `SparkDateTimeUtils` methods `toJavaDate` and `fromJavaDate` that are based on the default JVM
+ * time zone too.
*
- * @param pattern The pattern describing the date and time format.
- * See
- * Date and Time Patterns
- * @param locale The locale whose date format symbols should be used. It overrides the system
- * locale in parsing/formatting.
+ * @param pattern
+ * The pattern describing the date and time format. See Date and
+ * Time Patterns
+ * @param locale
+ * The locale whose date format symbols should be used. It overrides the system locale in
+ * parsing/formatting.
*/
// scalastyle:on line.size.limit
class LegacySimpleDateFormatter(pattern: String, locale: Locale) extends LegacyDateFormatter {
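
Underneath, Iso8601DateFormatter builds a java.time formatter from the pattern and locale and converts between LocalDate and days since the epoch. A rough standalone sketch of that round trip using plain java.time (the pattern below is assumed to match Spark's default date pattern):

    import java.time.LocalDate
    import java.time.format.DateTimeFormatter
    import java.util.Locale

    val formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd", Locale.US)

    val days = 19000                                                        // days since 1970-01-01
    val formatted = LocalDate.ofEpochDay(days).format(formatter)            // "2022-01-08"
    val parsedDays = LocalDate.parse(formatted, formatter).toEpochDay.toInt // back to 19000
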
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala
index 067e58893126c..71777906f868e 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala
@@ -40,13 +40,18 @@ trait DateTimeFormatterHelper {
}
private def verifyLocalDate(
- accessor: TemporalAccessor, field: ChronoField, candidate: LocalDate): Unit = {
+ accessor: TemporalAccessor,
+ field: ChronoField,
+ candidate: LocalDate): Unit = {
if (accessor.isSupported(field)) {
val actual = accessor.get(field)
val expected = candidate.get(field)
if (actual != expected) {
throw ExecutionErrors.fieldDiffersFromDerivedLocalDateError(
- field, actual, expected, candidate)
+ field,
+ actual,
+ expected,
+ candidate)
}
}
}
@@ -133,7 +138,8 @@ trait DateTimeFormatterHelper {
// SparkUpgradeException. On the contrary, if the legacy policy set to CORRECTED,
// DateTimeParseException will address by the caller side.
protected def checkParsedDiff[T](
- s: String, legacyParseFunc: String => T): PartialFunction[Throwable, T] = {
+ s: String,
+ legacyParseFunc: String => T): PartialFunction[Throwable, T] = {
case e if needConvertToSparkUpgradeException(e) =>
try {
legacyParseFunc(s)
@@ -151,11 +157,12 @@ trait DateTimeFormatterHelper {
d: T,
legacyFormatFunc: T => String): PartialFunction[Throwable, String] = {
case e if needConvertToSparkUpgradeException(e) =>
- val resultCandidate = try {
- legacyFormatFunc(d)
- } catch {
- case _: Throwable => throw e
- }
+ val resultCandidate =
+ try {
+ legacyFormatFunc(d)
+ } catch {
+ case _: Throwable => throw e
+ }
throw ExecutionErrors.failToParseDateTimeInNewParserError(resultCandidate, e)
}
@@ -166,9 +173,11 @@ trait DateTimeFormatterHelper {
* policy or follow our guide to correct their pattern. Otherwise, the original
* IllegalArgumentException will be thrown.
*
- * @param pattern the date time pattern
- * @param tryLegacyFormatter a func to capture exception, identically which forces a legacy
- * datetime formatter to be initialized
+ * @param pattern
+ * the date time pattern
+ * @param tryLegacyFormatter
+ * a function to capture the exception, which forces a legacy datetime formatter to be
+ * initialized
*/
protected def checkLegacyFormatter(
pattern: String,
@@ -214,8 +223,7 @@ private object DateTimeFormatterHelper {
/**
* Building a formatter for parsing seconds fraction with variable length
*/
- def createBuilderWithVarLengthSecondFraction(
- pattern: String): DateTimeFormatterBuilder = {
+ def createBuilderWithVarLengthSecondFraction(pattern: String): DateTimeFormatterBuilder = {
val builder = createBuilder()
pattern.split("'").zipWithIndex.foreach {
// Split string starting with the regex itself which is `'` here will produce an extra empty
@@ -229,12 +237,14 @@ private object DateTimeFormatterHelper {
case extractor(prefix, secondFraction, suffix) =>
builder.appendPattern(prefix)
if (secondFraction.nonEmpty) {
- builder.appendFraction(ChronoField.NANO_OF_SECOND, 1, secondFraction.length, false)
+ builder
+ .appendFraction(ChronoField.NANO_OF_SECOND, 1, secondFraction.length, false)
}
rest = suffix
- case _ => throw new SparkIllegalArgumentException(
- errorClass = "INVALID_DATETIME_PATTERN",
- messageParameters = Map("pattern" -> pattern))
+ case _ =>
+ throw new SparkIllegalArgumentException(
+ errorClass = "INVALID_DATETIME_PATTERN.SECONDS_FRACTION",
+ messageParameters = Map("pattern" -> pattern))
}
}
case (patternPart, _) => builder.appendLiteral(patternPart)
@@ -258,8 +268,10 @@ private object DateTimeFormatterHelper {
val builder = createBuilder()
.append(DateTimeFormatter.ISO_LOCAL_DATE)
.appendLiteral(' ')
- .appendValue(ChronoField.HOUR_OF_DAY, 2).appendLiteral(':')
- .appendValue(ChronoField.MINUTE_OF_HOUR, 2).appendLiteral(':')
+ .appendValue(ChronoField.HOUR_OF_DAY, 2)
+ .appendLiteral(':')
+ .appendValue(ChronoField.MINUTE_OF_HOUR, 2)
+ .appendLiteral(':')
.appendValue(ChronoField.SECOND_OF_MINUTE, 2)
.appendFraction(ChronoField.NANO_OF_SECOND, 0, 9, true)
toFormatter(builder, TimestampFormatter.defaultLocale)
@@ -299,17 +311,21 @@ private object DateTimeFormatterHelper {
* parsing/formatting datetime values. The pattern string is incompatible with the one defined
* by SimpleDateFormat in Spark 2.4 and earlier. This function converts all incompatible pattern
* for the new parser in Spark 3.0. See more details in SPARK-31030.
- * @param pattern The input pattern.
- * @return The pattern for new parser
+ * @param pattern
+ * The input pattern.
+ * @return
+ * The pattern for new parser
*/
def convertIncompatiblePattern(pattern: String, isParsing: Boolean): String = {
- val eraDesignatorContained = pattern.split("'").zipWithIndex.exists {
- case (patternPart, index) =>
+ val eraDesignatorContained =
+ pattern.split("'").zipWithIndex.exists { case (patternPart, index) =>
// Text can be quoted using single quotes, we only check the non-quote parts.
index % 2 == 0 && patternPart.contains("G")
- }
- (pattern + " ").split("'").zipWithIndex.map {
- case (patternPart, index) =>
+ }
+ (pattern + " ")
+ .split("'")
+ .zipWithIndex
+ .map { case (patternPart, index) =>
if (index % 2 == 0) {
for (c <- patternPart if weekBasedLetters.contains(c)) {
throw new SparkIllegalArgumentException(
@@ -317,12 +333,10 @@ private object DateTimeFormatterHelper {
messageParameters = Map("c" -> c.toString))
}
for (c <- patternPart if unsupportedLetters.contains(c) ||
- (isParsing && unsupportedLettersForParsing.contains(c))) {
+ (isParsing && unsupportedLettersForParsing.contains(c))) {
throw new SparkIllegalArgumentException(
errorClass = "INVALID_DATETIME_PATTERN.ILLEGAL_CHARACTER",
- messageParameters = Map(
- "c" -> c.toString,
- "pattern" -> pattern))
+ messageParameters = Map("c" -> c.toString, "pattern" -> pattern))
}
for (style <- unsupportedPatternLengths if patternPart.contains(style)) {
throw new SparkIllegalArgumentException(
@@ -340,6 +354,8 @@ private object DateTimeFormatterHelper {
} else {
patternPart
}
- }.mkString("'").stripSuffix(" ")
+ }
+ .mkString("'")
+ .stripSuffix(" ")
}
}
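
The convertIncompatiblePattern code above leans on pattern.split("'").zipWithIndex: even-indexed chunks sit outside single quotes, so quoted literals are never inspected. A minimal illustration of that scan:

    def containsUnquoted(pattern: String, letter: String): Boolean =
      pattern.split("'").zipWithIndex.exists { case (part, index) =>
        index % 2 == 0 && part.contains(letter)
      }

    containsUnquoted("yyyy-MM-dd G", "G")   // true: the era designator is outside quotes
    containsUnquoted("yyyy-MM-dd 'G'", "G") // false: here G is just a quoted literal
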
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala
index 96c3fb81aa66f..b113bccc74dfb 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala
@@ -89,10 +89,7 @@ object MathUtils {
def floorMod(a: Long, b: Long): Long = withOverflow(Math.floorMod(a, b))
- def withOverflow[A](
- f: => A,
- hint: String = "",
- context: QueryContext = null): A = {
+ def withOverflow[A](f: => A, hint: String = "", context: QueryContext = null): A = {
try {
f
} catch {
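
withOverflow above wraps checked math operations such as Math.floorMod and, judging by the signature (the catch body is truncated in this hunk), turns the ArithmeticException they throw into a Spark arithmetic error that can carry a hint and query context. A stripped-down sketch of that shape, using plain exceptions instead of Spark's error classes:

    def withOverflowSketch[A](f: => A, hint: String = ""): A =
      try f
      catch {
        case e: ArithmeticException =>
          val suggestion = if (hint.nonEmpty) s" Use $hint to tolerate overflow." else ""
          throw new ArithmeticException(s"${e.getMessage}.$suggestion")
      }

    withOverflowSketch(Math.addExact(1L, 2L))                       // 3
    withOverflowSketch(Math.addExact(Long.MaxValue, 1L), "try_add") // throws "long overflow. Use try_add ..."
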
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/RebaseDateTime.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/RebaseDateTime.scala
index f9566b0e1fb13..9c043320dc812 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/RebaseDateTime.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/RebaseDateTime.scala
@@ -34,8 +34,8 @@ import org.apache.spark.util.SparkClassUtils
/**
* The collection of functions for rebasing days and microseconds from/to the hybrid calendar
- * (Julian + Gregorian since 1582-10-15) which is used by Spark 2.4 and earlier versions
- * to/from Proleptic Gregorian calendar which is used by Spark since version 3.0. See SPARK-26651.
+ * (Julian + Gregorian since 1582-10-15) which is used by Spark 2.4 and earlier versions to/from
+ * Proleptic Gregorian calendar which is used by Spark since version 3.0. See SPARK-26651.
*/
object RebaseDateTime {
@@ -46,20 +46,22 @@ object RebaseDateTime {
}
/**
- * Rebases days since the epoch from an original to an target calendar, for instance,
- * from a hybrid (Julian + Gregorian) to Proleptic Gregorian calendar.
+ * Rebases days since the epoch from an original to a target calendar, for instance, from a
+ * hybrid (Julian + Gregorian) to Proleptic Gregorian calendar.
*
* It finds the latest switch day which is less than the given `days`, and adds the difference
- * in days associated with the switch days to the given `days`.
- * The function is based on linear search which starts from the most recent switch days.
- * This allows to perform less comparisons for modern dates.
+ * in days associated with the switch days to the given `days`. The function is based on linear
+ search which starts from the most recent switch days. This allows performing fewer comparisons
+ for modern dates.
*
- * @param switches The days when difference in days between original and target calendar
- * was changed.
- * @param diffs The differences in days between calendars.
- * @param days The number of days since the epoch 1970-01-01 to be rebased
- * to the target calendar.
- * @return The rebased days.
+ * @param switches
+ * The days when difference in days between original and target calendar was changed.
+ * @param diffs
+ * The differences in days between calendars.
+ * @param days
+ * The number of days since the epoch 1970-01-01 to be rebased to the target calendar.
+ * @return
+ * The rebased days.
*/
private def rebaseDays(switches: Array[Int], diffs: Array[Int], days: Int): Int = {
var i = switches.length
@@ -77,9 +79,8 @@ object RebaseDateTime {
// Julian calendar). This array is not applicable for dates before the starting point.
// Rebasing switch days and diffs `julianGregDiffSwitchDay` and `julianGregDiffs`
// was generated by the `localRebaseJulianToGregorianDays` function.
- private val julianGregDiffSwitchDay = Array(
- -719164, -682945, -646420, -609895, -536845, -500320, -463795,
- -390745, -354220, -317695, -244645, -208120, -171595, -141427)
+ private val julianGregDiffSwitchDay = Array(-719164, -682945, -646420, -609895, -536845,
+ -500320, -463795, -390745, -354220, -317695, -244645, -208120, -171595, -141427)
final val lastSwitchJulianDay: Int = julianGregDiffSwitchDay.last
@@ -88,20 +89,20 @@ object RebaseDateTime {
/**
* Converts the given number of days since the epoch day 1970-01-01 to a local date in Julian
- * calendar, interprets the result as a local date in Proleptic Gregorian calendar, and takes the
- * number of days since the epoch from the Gregorian local date.
+ * calendar, interprets the result as a local date in Proleptic Gregorian calendar, and takes
+ * the number of days since the epoch from the Gregorian local date.
*
* This is used to guarantee backward compatibility, as Spark 2.4 and earlier versions use
- * Julian calendar for dates before 1582-10-15, while Spark 3.0 and later use Proleptic Gregorian
- * calendar. See SPARK-26651.
+ * Julian calendar for dates before 1582-10-15, while Spark 3.0 and later use Proleptic
+ * Gregorian calendar. See SPARK-26651.
*
- * For example:
- * Julian calendar: 1582-01-01 -> -141704
- * Proleptic Gregorian calendar: 1582-01-01 -> -141714
- * The code below converts -141704 to -141714.
+ * For example: Julian calendar: 1582-01-01 -> -141704 Proleptic Gregorian calendar: 1582-01-01
+ * -> -141714 The code below converts -141704 to -141714.
*
- * @param days The number of days since the epoch in Julian calendar. It can be negative.
- * @return The rebased number of days in Gregorian calendar.
+ * @param days
+ * The number of days since the epoch in Julian calendar. It can be negative.
+ * @return
+ * The rebased number of days in Gregorian calendar.
*/
private[sql] def localRebaseJulianToGregorianDays(days: Int): Int = {
val utcCal = new Calendar.Builder()
@@ -111,14 +112,15 @@ object RebaseDateTime {
.setTimeZone(TimeZoneUTC)
.setInstant(Math.multiplyExact(days, MILLIS_PER_DAY))
.build()
- val localDate = LocalDate.of(
- utcCal.get(YEAR),
- utcCal.get(MONTH) + 1,
- // The number of days will be added later to handle non-existing
- // Julian dates in Proleptic Gregorian calendar.
- // For example, 1000-02-29 exists in Julian calendar because 1000
- // is a leap year but it is not a leap year in Gregorian calendar.
- 1)
+ val localDate = LocalDate
+ .of(
+ utcCal.get(YEAR),
+ utcCal.get(MONTH) + 1,
+ // The number of days will be added later to handle non-existing
+ // Julian dates in Proleptic Gregorian calendar.
+ // For example, 1000-02-29 exists in Julian calendar because 1000
+ // is a leap year but it is not a leap year in Gregorian calendar.
+ 1)
.`with`(ChronoField.ERA, utcCal.get(ERA))
.plusDays(utcCal.get(DAY_OF_MONTH) - 1)
Math.toIntExact(localDate.toEpochDay)
@@ -129,8 +131,10 @@ object RebaseDateTime {
* pre-calculated rebasing array to save calculation. For dates of Before Common Era, the
* function falls back to the regular unoptimized version.
*
- * @param days The number of days since the epoch in Julian calendar. It can be negative.
- * @return The rebased number of days in Gregorian calendar.
+ * @param days
+ * The number of days since the epoch in Julian calendar. It can be negative.
+ * @return
+ * The rebased number of days in Gregorian calendar.
*/
def rebaseJulianToGregorianDays(days: Int): Int = {
if (days < julianCommonEraStartDay) {
@@ -143,18 +147,17 @@ object RebaseDateTime {
// The differences in days between Proleptic Gregorian and Julian dates.
// The diff at the index `i` is applicable for all days in the date interval:
// [gregJulianDiffSwitchDay(i), gregJulianDiffSwitchDay(i+1))
- private val gregJulianDiffs = Array(
- -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)
+ private val gregJulianDiffs =
+ Array(-2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)
// The sorted days in Proleptic Gregorian calendar when difference in days between
// Proleptic Gregorian and Julian was changed.
// The starting point is the `0001-01-01` (-719162 days since the epoch in
// Proleptic Gregorian calendar). This array is not applicable for dates before the starting point.
// Rebasing switch days and diffs `gregJulianDiffSwitchDay` and `gregJulianDiffs`
// was generated by the `localRebaseGregorianToJulianDays` function.
- private val gregJulianDiffSwitchDay = Array(
- -719162, -682944, -646420, -609896, -536847, -500323, -463799, -390750,
- -354226, -317702, -244653, -208129, -171605, -141436, -141435, -141434,
- -141433, -141432, -141431, -141430, -141429, -141428, -141427)
+ private val gregJulianDiffSwitchDay = Array(-719162, -682944, -646420, -609896, -536847,
+ -500323, -463799, -390750, -354226, -317702, -244653, -208129, -171605, -141436, -141435,
+ -141434, -141433, -141432, -141431, -141430, -141429, -141428, -141427)
final val lastSwitchGregorianDay: Int = gregJulianDiffSwitchDay.last
@@ -171,17 +174,16 @@ object RebaseDateTime {
* number of days since the epoch from the Julian local date.
*
* This is used to guarantee backward compatibility, as Spark 2.4 and earlier versions use
- * Julian calendar for dates before 1582-10-15, while Spark 3.0 and later use Proleptic Gregorian
- * calendar. See SPARK-26651.
+ * Julian calendar for dates before 1582-10-15, while Spark 3.0 and later use Proleptic
+ * Gregorian calendar. See SPARK-26651.
*
- * For example:
- * Proleptic Gregorian calendar: 1582-01-01 -> -141714
- * Julian calendar: 1582-01-01 -> -141704
- * The code below converts -141714 to -141704.
+ * For example: Proleptic Gregorian calendar: 1582-01-01 -> -141714 Julian calendar: 1582-01-01
+ * -> -141704 The code below converts -141714 to -141704.
*
- * @param days The number of days since the epoch in Proleptic Gregorian calendar.
- * It can be negative.
- * @return The rebased number of days in Julian calendar.
+ * @param days
+ * The number of days since the epoch in Proleptic Gregorian calendar. It can be negative.
+ * @return
+ * The rebased number of days in Julian calendar.
*/
private[sql] def localRebaseGregorianToJulianDays(days: Int): Int = {
var localDate = LocalDate.ofEpochDay(days)
@@ -204,8 +206,10 @@ object RebaseDateTime {
* pre-calculated rebasing array to save calculation. For dates of Before Common Era, the
* function falls back to the regular unoptimized version.
*
- * @param days The number of days since the epoch in Gregorian calendar. It can be negative.
- * @return The rebased number of days since the epoch in Julian calendar.
+ * @param days
+ * The number of days since the epoch in Gregorian calendar. It can be negative.
+ * @return
+ * The rebased number of days since the epoch in Julian calendar.
*/
def rebaseGregorianToJulianDays(days: Int): Int = {
if (days < gregorianCommonEraStartDay) {
@@ -215,10 +219,9 @@ object RebaseDateTime {
}
}
-
/**
- * The class describes JSON records with microseconds rebasing info.
- * Here is an example of JSON file:
+ * The class describes JSON records with microseconds rebasing info. Here is an example of JSON
+ * file:
* {{{
* [
* {
@@ -229,37 +232,44 @@ object RebaseDateTime {
* ]
* }}}
*
- * @param tz One of time zone ID which is expected to be acceptable by `ZoneId.of`.
- * @param switches An ordered array of seconds since the epoch when the diff between
- * two calendars are changed.
- * @param diffs Differences in seconds associated with elements of `switches`.
+ * @param tz
+ * One of time zone ID which is expected to be acceptable by `ZoneId.of`.
+ * @param switches
+ * An ordered array of seconds since the epoch when the diff between two calendars is
+ * changed.
+ * @param diffs
+ * Differences in seconds associated with elements of `switches`.
*/
private case class JsonRebaseRecord(tz: String, switches: Array[Long], diffs: Array[Long])
/**
* Rebasing info used to convert microseconds from an original to a target calendar.
*
- * @param switches An ordered array of microseconds since the epoch when the diff between
- * two calendars are changed.
- * @param diffs Differences in microseconds associated with elements of `switches`.
+ * @param switches
+ * An ordered array of microseconds since the epoch when the diff between two calendars is
+ * changed.
+ * @param diffs
+ * Differences in microseconds associated with elements of `switches`.
*/
private[sql] case class RebaseInfo(switches: Array[Long], diffs: Array[Long])
/**
- * Rebases micros since the epoch from an original to an target calendar, for instance,
- * from a hybrid (Julian + Gregorian) to Proleptic Gregorian calendar.
+ * Rebases micros since the epoch from an original to a target calendar, for instance, from a
+ * hybrid (Julian + Gregorian) to Proleptic Gregorian calendar.
*
* It finds the latest switch micros which is less than the given `micros`, and adds the
- * difference in micros associated with the switch micros to the given `micros`.
- * The function is based on linear search which starts from the most recent switch micros.
- * This allows to perform less comparisons for modern timestamps.
+ * difference in micros associated with the switch micros to the given `micros`. The function is
+ * based on linear search which starts from the most recent switch micros. This allows
+ * performing fewer comparisons for modern timestamps.
*
- * @param rebaseInfo The rebasing info contains an ordered micros when difference in micros
- * between original and target calendar was changed,
- * and differences in micros between calendars
- * @param micros The number of micros since the epoch 1970-01-01T00:00:00Z to be rebased
- * to the target calendar. It can be negative.
- * @return The rebased micros.
+ * @param rebaseInfo
+ * The rebasing info: an ordered array of micros at which the difference in micros between the
+ * original and target calendar was changed, and the differences in micros between calendars
+ * @param micros
+ * The number of micros since the epoch 1970-01-01T00:00:00Z to be rebased to the target
+ * calendar. It can be negative.
+ * @return
+ * The rebased micros.
*/
private def rebaseMicros(rebaseInfo: RebaseInfo, micros: Long): Long = {
val switches = rebaseInfo.switches
@@ -296,18 +306,19 @@ object RebaseDateTime {
/**
* A map of time zone IDs to its ordered time points (instants in microseconds since the epoch)
- * when the difference between 2 instances associated with the same local timestamp in
- * Proleptic Gregorian and the hybrid calendar was changed, and to the diff at the index `i` is
- * applicable for all microseconds in the time interval:
- * [gregJulianDiffSwitchMicros(i), gregJulianDiffSwitchMicros(i+1))
+ * when the difference between 2 instances associated with the same local timestamp in Proleptic
+ * Gregorian and the hybrid calendar was changed, and the diff at the index `i` is applicable
+ * for all microseconds in the time interval: [gregJulianDiffSwitchMicros(i),
+ * gregJulianDiffSwitchMicros(i+1))
*/
private val gregJulianRebaseMap = loadRebaseRecords("gregorian-julian-rebase-micros.json")
private def getLastSwitchTs(rebaseMap: AnyRefMap[String, RebaseInfo]): Long = {
val latestTs = rebaseMap.values.map(_.switches.last).max
- require(rebaseMap.values.forall(_.diffs.last == 0),
+ require(
+ rebaseMap.values.forall(_.diffs.last == 0),
s"Differences between Julian and Gregorian calendar after ${microsToInstant(latestTs)} " +
- "are expected to be zero for all available time zones.")
+ "are expected to be zero for all available time zones.")
latestTs
}
// The switch time point after which all diffs between Gregorian and Julian calendars
@@ -315,29 +326,30 @@ object RebaseDateTime {
final val lastSwitchGregorianTs: Long = getLastSwitchTs(gregJulianRebaseMap)
private final val gregorianStartTs = LocalDateTime.of(gregorianStartDate, LocalTime.MIDNIGHT)
- private final val julianEndTs = LocalDateTime.of(
- julianEndDate,
- LocalTime.of(23, 59, 59, 999999999))
+ private final val julianEndTs =
+ LocalDateTime.of(julianEndDate, LocalTime.of(23, 59, 59, 999999999))
/**
* Converts the given number of microseconds since the epoch '1970-01-01T00:00:00Z', to a local
- * date-time in Proleptic Gregorian calendar with timezone identified by `zoneId`, interprets the
- * result as a local date-time in Julian calendar with the same timezone, and takes microseconds
- * since the epoch from the Julian local date-time.
+ * date-time in Proleptic Gregorian calendar with timezone identified by `zoneId`, interprets
+ * the result as a local date-time in Julian calendar with the same timezone, and takes
+ * microseconds since the epoch from the Julian local date-time.
*
* This is used to guarantee backward compatibility, as Spark 2.4 and earlier versions use
- * Julian calendar for dates before 1582-10-15, while Spark 3.0 and later use Proleptic Gregorian
- * calendar. See SPARK-26651.
+ * Julian calendar for dates before 1582-10-15, while Spark 3.0 and later use Proleptic
+ * Gregorian calendar. See SPARK-26651.
*
- * For example:
- * Proleptic Gregorian calendar: 1582-01-01 00:00:00.123456 -> -12244061221876544
- * Julian calendar: 1582-01-01 00:00:00.123456 -> -12243196799876544
- * The code below converts -12244061221876544 to -12243196799876544.
+ * For example: Proleptic Gregorian calendar: 1582-01-01 00:00:00.123456 -> -12244061221876544
+ * Julian calendar: 1582-01-01 00:00:00.123456 -> -12243196799876544 The code below converts
+ * -12244061221876544 to -12243196799876544.
*
- * @param tz The time zone at which the rebasing should be performed.
- * @param micros The number of microseconds since the epoch '1970-01-01T00:00:00Z'
- * in Proleptic Gregorian calendar. It can be negative.
- * @return The rebased microseconds since the epoch in Julian calendar.
+ * @param tz
+ * The time zone at which the rebasing should be performed.
+ * @param micros
+ * The number of microseconds since the epoch '1970-01-01T00:00:00Z' in Proleptic Gregorian
+ * calendar. It can be negative.
+ * @return
+ * The rebased microseconds since the epoch in Julian calendar.
*/
private[sql] def rebaseGregorianToJulianMicros(tz: TimeZone, micros: Long): Long = {
val instant = microsToInstant(micros)
@@ -380,10 +392,13 @@ object RebaseDateTime {
* contain information about the given time zone `timeZoneId` or `micros` is related to Before
* Common Era, the function falls back to the regular unoptimized version.
*
- * @param timeZoneId A string identifier of a time zone.
- * @param micros The number of microseconds since the epoch '1970-01-01T00:00:00Z'
- * in Proleptic Gregorian calendar. It can be negative.
- * @return The rebased microseconds since the epoch in Julian calendar.
+ * @param timeZoneId
+ * A string identifier of a time zone.
+ * @param micros
+ * The number of microseconds since the epoch '1970-01-01T00:00:00Z' in Proleptic Gregorian
+ * calendar. It can be negative.
+ * @return
+ * The rebased microseconds since the epoch in Julian calendar.
*/
def rebaseGregorianToJulianMicros(timeZoneId: String, micros: Long): Long = {
if (micros >= lastSwitchGregorianTs) {
@@ -404,12 +419,14 @@ object RebaseDateTime {
* contain information about the current JVM system time zone or `micros` is related to Before
* Common Era, the function falls back to the regular unoptimized version.
*
- * Note: The function assumes that the input micros was derived from a local timestamp
- * at the default system JVM time zone in Proleptic Gregorian calendar.
+ * Note: The function assumes that the input micros was derived from a local timestamp at the
+ * default system JVM time zone in Proleptic Gregorian calendar.
*
- * @param micros The number of microseconds since the epoch '1970-01-01T00:00:00Z'
- * in Proleptic Gregorian calendar. It can be negative.
- * @return The rebased microseconds since the epoch in Julian calendar.
+ * @param micros
+ * The number of microseconds since the epoch '1970-01-01T00:00:00Z' in Proleptic Gregorian
+ * calendar. It can be negative.
+ * @return
+ * The rebased microseconds since the epoch in Julian calendar.
*/
def rebaseGregorianToJulianMicros(micros: Long): Long = {
rebaseGregorianToJulianMicros(TimeZone.getDefault.getID, micros)
@@ -418,22 +435,24 @@ object RebaseDateTime {
/**
* Converts the given number of microseconds since the epoch '1970-01-01T00:00:00Z', to a local
* date-time in Julian calendar with timezone identified by `zoneId`, interprets the result as a
- * local date-time in Proleptic Gregorian calendar with the same timezone, and takes microseconds
- * since the epoch from the Gregorian local date-time.
+ * local date-time in Proleptic Gregorian calendar with the same timezone, and takes
+ * microseconds since the epoch from the Gregorian local date-time.
*
* This is used to guarantee backward compatibility, as Spark 2.4 and earlier versions use
- * Julian calendar for dates before 1582-10-15, while Spark 3.0 and later use Proleptic Gregorian
- * calendar. See SPARK-26651.
+ * Julian calendar for dates before 1582-10-15, while Spark 3.0 and later use Proleptic
+ * Gregorian calendar. See SPARK-26651.
*
- * For example:
- * Julian calendar: 1582-01-01 00:00:00.123456 -> -12243196799876544
- * Proleptic Gregorian calendar: 1582-01-01 00:00:00.123456 -> -12244061221876544
- * The code below converts -12243196799876544 to -12244061221876544.
+ * For example: Julian calendar: 1582-01-01 00:00:00.123456 -> -12243196799876544 Proleptic
+ * Gregorian calendar: 1582-01-01 00:00:00.123456 -> -12244061221876544 The code below converts
+ * -12243196799876544 to -12244061221876544.
*
- * @param tz The time zone at which the rebasing should be performed.
- * @param micros The number of microseconds since the epoch '1970-01-01T00:00:00Z'
- * in the Julian calendar. It can be negative.
- * @return The rebased microseconds since the epoch in Proleptic Gregorian calendar.
+ * @param tz
+ * The time zone at which the rebasing should be performed.
+ * @param micros
+ * The number of microseconds since the epoch '1970-01-01T00:00:00Z' in the Julian calendar.
+ * It can be negative.
+ * @return
+ * The rebased microseconds since the epoch in Proleptic Gregorian calendar.
*/
private[sql] def rebaseJulianToGregorianMicros(tz: TimeZone, micros: Long): Long = {
val cal = new Calendar.Builder()
@@ -442,18 +461,19 @@ object RebaseDateTime {
.setInstant(microsToMillis(micros))
.setTimeZone(tz)
.build()
- val localDateTime = LocalDateTime.of(
- cal.get(YEAR),
- cal.get(MONTH) + 1,
- // The number of days will be added later to handle non-existing
- // Julian dates in Proleptic Gregorian calendar.
- // For example, 1000-02-29 exists in Julian calendar because 1000
- // is a leap year but it is not a leap year in Gregorian calendar.
- 1,
- cal.get(HOUR_OF_DAY),
- cal.get(MINUTE),
- cal.get(SECOND),
- (Math.floorMod(micros, MICROS_PER_SECOND) * NANOS_PER_MICROS).toInt)
+ val localDateTime = LocalDateTime
+ .of(
+ cal.get(YEAR),
+ cal.get(MONTH) + 1,
+ // The number of days will be added later to handle non-existing
+ // Julian dates in Proleptic Gregorian calendar.
+ // For example, 1000-02-29 exists in Julian calendar because 1000
+ // is a leap year but it is not a leap year in Gregorian calendar.
+ 1,
+ cal.get(HOUR_OF_DAY),
+ cal.get(MINUTE),
+ cal.get(SECOND),
+ (Math.floorMod(micros, MICROS_PER_SECOND) * NANOS_PER_MICROS).toInt)
.`with`(ChronoField.ERA, cal.get(ERA))
.plusDays(cal.get(DAY_OF_MONTH) - 1)
val zoneId = tz.toZoneId
@@ -494,10 +514,13 @@ object RebaseDateTime {
* contain information about the given time zone `timeZoneId` or `micros` is related to Before
* Common Era, the function falls back to the regular unoptimized version.
*
- * @param timeZoneId A string identifier of a time zone.
- * @param micros The number of microseconds since the epoch '1970-01-01T00:00:00Z'
- * in the Julian calendar. It can be negative.
- * @return The rebased microseconds since the epoch in Proleptic Gregorian calendar.
+ * @param timeZoneId
+ * A string identifier of a time zone.
+ * @param micros
+ * The number of microseconds since the epoch '1970-01-01T00:00:00Z' in the Julian calendar.
+ * It can be negative.
+ * @return
+ * The rebased microseconds since the epoch in Proleptic Gregorian calendar.
*/
def rebaseJulianToGregorianMicros(timeZoneId: String, micros: Long): Long = {
if (micros >= lastSwitchJulianTs) {
@@ -518,12 +541,14 @@ object RebaseDateTime {
* contain information about the current JVM system time zone or `micros` is related to Before
* Common Era, the function falls back to the regular unoptimized version.
*
- * Note: The function assumes that the input micros was derived from a local timestamp
- * at the default system JVM time zone in the Julian calendar.
+ * Note: The function assumes that the input micros was derived from a local timestamp at the
+ * default system JVM time zone in the Julian calendar.
*
- * @param micros The number of microseconds since the epoch '1970-01-01T00:00:00Z'
- * in the Julian calendar. It can be negative.
- * @return The rebased microseconds since the epoch in Proleptic Gregorian calendar.
+ * @param micros
+ * The number of microseconds since the epoch '1970-01-01T00:00:00Z' in the Julian calendar.
+ * It can be negative.
+ * @return
+ * The rebased microseconds since the epoch in Proleptic Gregorian calendar.
*/
def rebaseJulianToGregorianMicros(micros: Long): Long = {
rebaseJulianToGregorianMicros(TimeZone.getDefault.getID, micros)
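
The rebaseDays helper documented above is a short linear search: walking back from the most recent switch day, find the latest switch at or before the input and add the matching diff. A sketch of that logic with toy arrays (the real tables are the julianGregDiff*/gregJulianDiff* values above):

    def rebaseDaysSketch(switches: Array[Int], diffs: Array[Int], days: Int): Int = {
      var i = switches.length - 1
      while (i > 0 && days < switches(i)) i -= 1 // walk back from the most recent switch
      days + diffs(i)
    }

    // Toy values, not the real rebase tables:
    val switches = Array(-719164, -141427)
    val diffs    = Array(-2, 0)
    rebaseDaysSketch(switches, diffs, -200000) // -200002: a pre-1582 day gets shifted
    rebaseDaysSketch(switches, diffs, 0)       // 0: modern dates are unchanged
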
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkCharVarcharUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkCharVarcharUtils.scala
index 498eb83566eb3..2a26c079e8d4d 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkCharVarcharUtils.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkCharVarcharUtils.scala
@@ -21,15 +21,18 @@ import org.apache.spark.sql.internal.SqlApiConf
import org.apache.spark.sql.types.{ArrayType, CharType, DataType, MapType, StringType, StructType, VarcharType}
trait SparkCharVarcharUtils {
+
/**
- * Returns true if the given data type is CharType/VarcharType or has nested CharType/VarcharType.
+ * Returns true if the given data type is CharType/VarcharType or has nested
+ * CharType/VarcharType.
*/
def hasCharVarchar(dt: DataType): Boolean = {
dt.existsRecursively(f => f.isInstanceOf[CharType] || f.isInstanceOf[VarcharType])
}
/**
- * Validate the given [[DataType]] to fail if it is char or varchar types or contains nested ones
+ * Validate the given [[DataType]] to fail if it is a char or varchar type or contains nested
+ * ones
*/
def failIfHasCharVarchar(dt: DataType): DataType = {
if (!SqlApiConf.get.charVarcharAsString && hasCharVarchar(dt)) {
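
hasCharVarchar above relies on DataType.existsRecursively; a hand-rolled walk over nested types with the same effect looks roughly like this (illustrative only, matching on the public types imported above):

    import org.apache.spark.sql.types._

    def containsCharVarchar(dt: DataType): Boolean = dt match {
      case _: CharType | _: VarcharType => true
      case s: StructType => s.fields.exists(f => containsCharVarchar(f.dataType))
      case a: ArrayType  => containsCharVarchar(a.elementType)
      case m: MapType    => containsCharVarchar(m.keyType) || containsCharVarchar(m.valueType)
      case _ => false
    }

    containsCharVarchar(StructType(Seq(StructField("tags", ArrayType(VarcharType(32)))))) // true
    containsCharVarchar(new StructType().add("id", LongType))                             // false
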
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala
index a6592ad51c65c..4e94bc6617357 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala
@@ -54,8 +54,10 @@ trait SparkDateTimeUtils {
/**
* Converts an Java object to days.
*
- * @param obj Either an object of `java.sql.Date` or `java.time.LocalDate`.
- * @return The number of days since 1970-01-01.
+ * @param obj
+ * Either an object of `java.sql.Date` or `java.time.LocalDate`.
+ * @return
+ * The number of days since 1970-01-01.
*/
def anyToDays(obj: Any): Int = obj match {
case d: Date => fromJavaDate(d)
@@ -65,8 +67,10 @@ trait SparkDateTimeUtils {
/**
* Converts an Java object to microseconds.
*
- * @param obj Either an object of `java.sql.Timestamp` or `java.time.{Instant,LocalDateTime}`.
- * @return The number of micros since the epoch.
+ * @param obj
+ * Either an object of `java.sql.Timestamp` or `java.time.{Instant,LocalDateTime}`.
+ * @return
+ * The number of micros since the epoch.
*/
def anyToMicros(obj: Any): Long = obj match {
case t: Timestamp => fromJavaTimestamp(t)
@@ -75,8 +79,8 @@ trait SparkDateTimeUtils {
}
/**
- * Converts the timestamp to milliseconds since epoch. In Spark timestamp values have microseconds
- * precision, so this conversion is lossy.
+ * Converts the timestamp to milliseconds since epoch. In Spark, timestamp values have
+ * microsecond precision, so this conversion is lossy.
*/
def microsToMillis(micros: Long): Long = {
// When the timestamp is negative i.e before 1970, we need to adjust the milliseconds portion.
@@ -97,8 +101,8 @@ trait SparkDateTimeUtils {
private val MIN_SECONDS = Math.floorDiv(Long.MinValue, MICROS_PER_SECOND)
/**
- * Obtains an instance of `java.time.Instant` using microseconds from
- * the epoch of 1970-01-01 00:00:00Z.
+ * Obtains an instance of `java.time.Instant` using microseconds from the epoch of 1970-01-01
+ * 00:00:00Z.
*/
def microsToInstant(micros: Long): Instant = {
val secs = Math.floorDiv(micros, MICROS_PER_SECOND)
@@ -110,8 +114,8 @@ trait SparkDateTimeUtils {
/**
* Gets the number of microseconds since the epoch of 1970-01-01 00:00:00Z from the given
- * instance of `java.time.Instant`. The epoch microsecond count is a simple incrementing count of
- * microseconds where microsecond 0 is 1970-01-01 00:00:00Z.
+ * instance of `java.time.Instant`. The epoch microsecond count is a simple incrementing count
+ * of microseconds where microsecond 0 is 1970-01-01 00:00:00Z.
*/
def instantToMicros(instant: Instant): Long = {
val secs = instant.getEpochSecond
@@ -127,8 +131,8 @@ trait SparkDateTimeUtils {
/**
* Converts the timestamp `micros` from one timezone to another.
*
- * Time-zone rules, such as daylight savings, mean that not every local date-time
- * is valid for the `toZone` time zone, thus the local date-time may be adjusted.
+ * Time-zone rules, such as daylight savings, mean that not every local date-time is valid for
+ * the `toZone` time zone, thus the local date-time may be adjusted.
*/
def convertTz(micros: Long, fromZone: ZoneId, toZone: ZoneId): Long = {
val rebasedDateTime = getLocalDateTime(micros, toZone).atZone(fromZone)
@@ -160,14 +164,16 @@ trait SparkDateTimeUtils {
def daysToLocalDate(days: Int): LocalDate = LocalDate.ofEpochDay(days)
/**
- * Converts microseconds since 1970-01-01 00:00:00Z to days since 1970-01-01 at the given zone ID.
+ * Converts microseconds since 1970-01-01 00:00:00Z to days since 1970-01-01 at the given zone
+ * ID.
*/
def microsToDays(micros: Long, zoneId: ZoneId): Int = {
localDateToDays(getLocalDateTime(micros, zoneId).toLocalDate)
}
/**
- * Converts days since 1970-01-01 at the given zone ID to microseconds since 1970-01-01 00:00:00Z.
+ * Converts days since 1970-01-01 at the given zone ID to microseconds since 1970-01-01
+ * 00:00:00Z.
*/
def daysToMicros(days: Int, zoneId: ZoneId): Long = {
val instant = daysToLocalDate(days).atStartOfDay(zoneId).toInstant
@@ -175,20 +181,22 @@ trait SparkDateTimeUtils {
}
/**
- * Converts a local date at the default JVM time zone to the number of days since 1970-01-01
- * in the hybrid calendar (Julian + Gregorian) by discarding the time part. The resulted days are
+ * Converts a local date at the default JVM time zone to the number of days since 1970-01-01 in
+ * the hybrid calendar (Julian + Gregorian) by discarding the time part. The resulting days are
* rebased from the hybrid to Proleptic Gregorian calendar. The days rebasing is performed via
- * UTC time zone for simplicity because the difference between two calendars is the same in
- * any given time zone and UTC time zone.
+ * UTC time zone for simplicity because the difference between two calendars is the same in any
+ * given time zone and UTC time zone.
*
- * Note: The date is shifted by the offset of the default JVM time zone for backward compatibility
- * with Spark 2.4 and earlier versions. The goal of the shift is to get a local date derived
- * from the number of days that has the same date fields (year, month, day) as the original
- * `date` at the default JVM time zone.
+ * Note: The date is shifted by the offset of the default JVM time zone for backward
+ * compatibility with Spark 2.4 and earlier versions. The goal of the shift is to get a local
+ * date derived from the number of days that has the same date fields (year, month, day) as the
+ * original `date` at the default JVM time zone.
*
- * @param date It represents a specific instant in time based on the hybrid calendar which
- * combines Julian and Gregorian calendars.
- * @return The number of days since the epoch in Proleptic Gregorian calendar.
+ * @param date
+ * It represents a specific instant in time based on the hybrid calendar which combines Julian
+ * and Gregorian calendars.
+ * @return
+ * The number of days since the epoch in Proleptic Gregorian calendar.
*/
def fromJavaDate(date: Date): Int = {
val millisUtc = date.getTime
@@ -207,18 +215,20 @@ trait SparkDateTimeUtils {
}
/**
- * Converts days since the epoch 1970-01-01 in Proleptic Gregorian calendar to a local date
- * at the default JVM time zone in the hybrid calendar (Julian + Gregorian). It rebases the given
+ * Converts days since the epoch 1970-01-01 in Proleptic Gregorian calendar to a local date at
+ * the default JVM time zone in the hybrid calendar (Julian + Gregorian). It rebases the given
* days from Proleptic Gregorian to the hybrid calendar at UTC time zone for simplicity because
* the difference between two calendars doesn't depend on any time zone. The result is shifted
- * by the time zone offset in wall clock to have the same date fields (year, month, day)
- * at the default JVM time zone as the input `daysSinceEpoch` in Proleptic Gregorian calendar.
+ * by the time zone offset in wall clock to have the same date fields (year, month, day) at the
+ * default JVM time zone as the input `daysSinceEpoch` in Proleptic Gregorian calendar.
*
- * Note: The date is shifted by the offset of the default JVM time zone for backward compatibility
- * with Spark 2.4 and earlier versions.
+ * Note: The date is shifted by the offset of the default JVM time zone for backward
+ * compatibility with Spark 2.4 and earlier versions.
*
- * @param days The number of days since 1970-01-01 in Proleptic Gregorian calendar.
- * @return A local date in the hybrid calendar as `java.sql.Date` from number of days since epoch.
+ * @param days
+ * The number of days since 1970-01-01 in Proleptic Gregorian calendar.
+ * @return
+ * A local date in the hybrid calendar as `java.sql.Date` from number of days since epoch.
*/
def toJavaDate(days: Int): Date = {
val rebasedDays = rebaseGregorianToJulianDays(days)
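// The reverse direction, sketched the same way (again without the explicit rebase): build a
// java.sql.Date whose fields at the default JVM time zone match the Proleptic Gregorian day.
import java.sql.Date
import java.time.LocalDate

object ToJavaDateSketch {
  def toJavaDate(days: Int): Date = Date.valueOf(LocalDate.ofEpochDay(days.toLong))
}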
@@ -233,20 +243,22 @@ trait SparkDateTimeUtils {
}
/**
- * Converts microseconds since the epoch to an instance of `java.sql.Timestamp`
- * via creating a local timestamp at the system time zone in Proleptic Gregorian
- * calendar, extracting date and time fields like `year` and `hours`, and forming
- * new timestamp in the hybrid calendar from the extracted fields.
+ * Converts microseconds since the epoch to an instance of `java.sql.Timestamp` via creating a
+ * local timestamp at the system time zone in Proleptic Gregorian calendar, extracting date and
+ * time fields like `year` and `hours`, and forming new timestamp in the hybrid calendar from
+ * the extracted fields.
*
- * The conversion is based on the JVM system time zone because the `java.sql.Timestamp`
- * uses the time zone internally.
+ * The conversion is based on the JVM system time zone because the `java.sql.Timestamp` uses the
+ * time zone internally.
*
* The method performs the conversion via local timestamp fields to have the same date-time
- * representation as `year`, `month`, `day`, ..., `seconds` in the original calendar
- * and in the target calendar.
+ * representation as `year`, `month`, `day`, ..., `seconds` in the original calendar and in the
+ * target calendar.
*
- * @param micros The number of microseconds since 1970-01-01T00:00:00.000000Z.
- * @return A `java.sql.Timestamp` from number of micros since epoch.
+ * @param micros
+ * The number of microseconds since 1970-01-01T00:00:00.000000Z.
+ * @return
+ * A `java.sql.Timestamp` from number of micros since epoch.
*/
def toJavaTimestamp(micros: Long): Timestamp =
toJavaTimestampNoRebase(rebaseGregorianToJulianMicros(micros))
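// Sketch of the no-rebase variant declared a few lines below, spelled out with floorDiv and
// floorMod; MicrosPerSecond is an assumption of this snippet, not the trait's constant.
import java.sql.Timestamp

object ToJavaTimestampSketch {
  private val MicrosPerSecond = 1000000L

  def toJavaTimestampNoRebase(micros: Long): Timestamp = {
    val seconds = Math.floorDiv(micros, MicrosPerSecond)
    val ts = new Timestamp(Math.multiplyExact(seconds, 1000L))            // whole seconds, in millis
    ts.setNanos((Math.floorMod(micros, MicrosPerSecond) * 1000L).toInt)   // sub-second part
    ts
  }
}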
@@ -257,8 +269,10 @@ trait SparkDateTimeUtils {
/**
* Converts microseconds since the epoch to an instance of `java.sql.Timestamp`.
*
- * @param micros The number of microseconds since 1970-01-01T00:00:00.000000Z.
- * @return A `java.sql.Timestamp` from number of micros since epoch.
+ * @param micros
+ * The number of microseconds since 1970-01-01T00:00:00.000000Z.
+ * @return
+ * A `java.sql.Timestamp` from number of micros since epoch.
*/
def toJavaTimestampNoRebase(micros: Long): Timestamp = {
val seconds = Math.floorDiv(micros, MICROS_PER_SECOND)
@@ -270,22 +284,22 @@ trait SparkDateTimeUtils {
/**
* Converts an instance of `java.sql.Timestamp` to the number of microseconds since
- * 1970-01-01T00:00:00.000000Z. It extracts date-time fields from the input, builds
- * a local timestamp in Proleptic Gregorian calendar from the fields, and binds
- * the timestamp to the system time zone. The resulted instant is converted to
- * microseconds since the epoch.
+ * 1970-01-01T00:00:00.000000Z. It extracts date-time fields from the input, builds a local
+ * timestamp in Proleptic Gregorian calendar from the fields, and binds the timestamp to the
+ * system time zone. The resulted instant is converted to microseconds since the epoch.
*
- * The conversion is performed via the system time zone because it is used internally
- * in `java.sql.Timestamp` while extracting date-time fields.
+ * The conversion is performed via the system time zone because it is used internally in
+ * `java.sql.Timestamp` while extracting date-time fields.
*
* The goal of the function is to have the same local date-time in the original calendar
- * - the hybrid calendar (Julian + Gregorian) and in the target calendar which is
- * Proleptic Gregorian calendar, see SPARK-26651.
+ * - the hybrid calendar (Julian + Gregorian) and in the target calendar which is Proleptic
+ * Gregorian calendar, see SPARK-26651.
*
- * @param t It represents a specific instant in time based on
- * the hybrid calendar which combines Julian and
- * Gregorian calendars.
- * @return The number of micros since epoch from `java.sql.Timestamp`.
+ * @param t
+ * It represents a specific instant in time based on the hybrid calendar which combines Julian
+ * and Gregorian calendars.
+ * @return
+ * The number of micros since epoch from `java.sql.Timestamp`.
*/
def fromJavaTimestamp(t: Timestamp): Long =
rebaseJulianToGregorianMicros(fromJavaTimestampNoRebase(t))
@@ -297,30 +311,27 @@ trait SparkDateTimeUtils {
* Converts an instance of `java.sql.Timestamp` to the number of microseconds since
* 1970-01-01T00:00:00.000000Z.
*
- * @param t an instance of `java.sql.Timestamp`.
- * @return The number of micros since epoch from `java.sql.Timestamp`.
+ * @param t
+ * an instance of `java.sql.Timestamp`.
+ * @return
+ * The number of micros since epoch from `java.sql.Timestamp`.
*/
def fromJavaTimestampNoRebase(t: Timestamp): Long =
millisToMicros(t.getTime) + (t.getNanos / NANOS_PER_MICROS) % MICROS_PER_MILLIS
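// The formula right above, written out as a standalone helper: getTime() already carries the
// millisecond part, so only the sub-millisecond remainder of getNanos() is added back.
import java.sql.Timestamp

object FromJavaTimestampSketch {
  private val MicrosPerMillis = 1000L
  private val NanosPerMicros = 1000L

  def fromJavaTimestampNoRebase(t: Timestamp): Long =
    Math.addExact(
      Math.multiplyExact(t.getTime, MicrosPerMillis),
      (t.getNanos / NanosPerMicros) % MicrosPerMillis)
}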
/**
- * Trims and parses a given UTF8 date string to a corresponding [[Int]] value.
- * The return type is [[Option]] in order to distinguish between 0 and null. The following
- * formats are allowed:
+ * Trims and parses a given UTF8 date string to a corresponding [[Int]] value. The return type
+ * is [[Option]] in order to distinguish between 0 and null. The following formats are allowed:
*
- * `[+-]yyyy*`
- * `[+-]yyyy*-[m]m`
- * `[+-]yyyy*-[m]m-[d]d`
- * `[+-]yyyy*-[m]m-[d]d `
- * `[+-]yyyy*-[m]m-[d]d *`
- * `[+-]yyyy*-[m]m-[d]dT*`
+ * `[+-]yyyy*` `[+-]yyyy*-[m]m` `[+-]yyyy*-[m]m-[d]d` `[+-]yyyy*-[m]m-[d]d `
+ * `[+-]yyyy*-[m]m-[d]d *` `[+-]yyyy*-[m]m-[d]dT*`
*/
def stringToDate(s: UTF8String): Option[Int] = {
def isValidDigits(segment: Int, digits: Int): Boolean = {
// An integer is able to represent a date within [+-]5 million years.
val maxDigitsYear = 7
(segment == 0 && digits >= 4 && digits <= maxDigitsYear) ||
- (segment != 0 && digits > 0 && digits <= 2)
+ (segment != 0 && digits > 0 && digits <= 2)
}
if (s == null) {
return None
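// A loose, standalone approximation of the accepted shapes listed above: trim, stop at the
// first space or 'T', split on '-', and default a missing month or day to 1. It is meant only
// to make the format list concrete, not to reproduce the validation in stringToDate itself.
import java.time.LocalDate

object StringToDateSketch {
  def parse(s: String): Option[Int] = {
    val head = s.trim.takeWhile(c => c != ' ' && c != 'T')
    val (sign, body) =
      if (head.startsWith("-")) (-1, head.tail)
      else if (head.startsWith("+")) (1, head.tail)
      else (1, head)
    val parts = body.split('-')
    try {
      val year = sign * parts(0).toInt
      val month = if (parts.length > 1) parts(1).toInt else 1
      val day = if (parts.length > 2) parts(2).toInt else 1
      Some(LocalDate.of(year, month, day).toEpochDay.toInt)
    } catch {
      case _: RuntimeException => None // bad digits or out-of-range fields
    }
  }
}

// parse("2021-7-4"), parse("2021-07") and parse("+2021-07-04T12:00:00") all yield Some(...),
// while parse("2021-13-01") yields None.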
@@ -380,24 +391,18 @@ trait SparkDateTimeUtils {
}
}
- def stringToDateAnsi(
- s: UTF8String,
- context: QueryContext = null): Int = {
+ def stringToDateAnsi(s: UTF8String, context: QueryContext = null): Int = {
stringToDate(s).getOrElse {
throw ExecutionErrors.invalidInputInCastToDatetimeError(s, DateType, context)
}
}
/**
- * Trims and parses a given UTF8 timestamp string to the corresponding timestamp segments,
- * time zone id and whether it is just time without a date.
- * value. The return type is [[Option]] in order to distinguish between 0L and null. The following
- * formats are allowed:
+ * Trims and parses a given UTF8 timestamp string to the corresponding timestamp segments, time
+ * zone id and whether it is just time without a date. value. The return type is [[Option]] in
+ * order to distinguish between 0L and null. The following formats are allowed:
*
- * `[+-]yyyy*`
- * `[+-]yyyy*-[m]m`
- * `[+-]yyyy*-[m]m-[d]d`
- * `[+-]yyyy*-[m]m-[d]d `
+ * `[+-]yyyy*` `[+-]yyyy*-[m]m` `[+-]yyyy*-[m]m-[d]d` `[+-]yyyy*-[m]m-[d]d `
* `[+-]yyyy*-[m]m-[d]d [h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]`
* `[+-]yyyy*-[m]m-[d]dT[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]`
* `[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]`
@@ -407,16 +412,17 @@ trait SparkDateTimeUtils {
* - Z - Zulu time zone UTC+0
* - +|-[h]h:[m]m
* - A short id, see https://docs.oracle.com/javase/8/docs/api/java/time/ZoneId.html#SHORT_IDS
- * - An id with one of the prefixes UTC+, UTC-, GMT+, GMT-, UT+ or UT-,
- * and a suffix in the formats:
+ * - An id with one of the prefixes UTC+, UTC-, GMT+, GMT-, UT+ or UT-, and a suffix in the
+ * formats:
* - +|-h[h]
* - +|-hh[:]mm
* - +|-hh:mm:ss
* - +|-hhmmss
- * - Region-based zone IDs in the form `area/city`, such as `Europe/Paris`
+ * - Region-based zone IDs in the form `area/city`, such as `Europe/Paris`
*
- * @return timestamp segments, time zone id and whether the input is just time without a date. If
- * the input string can't be parsed as timestamp, the result timestamp segments are empty.
+ * @return
+ * timestamp segments, time zone id and whether the input is just time without a date. If the
+ * input string can't be parsed as timestamp, the result timestamp segments are empty.
*/
def parseTimestampString(s: UTF8String): (Array[Int], Option[ZoneId], Boolean) = {
def isValidDigits(segment: Int, digits: Int): Boolean = {
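// The zone-id shapes enumerated above line up closely with what java.time accepts when the
// short-id alias map is supplied; a quick standalone check of a few of the documented forms:
import java.time.ZoneId

object ZoneIdFormsSketch {
  def parseZone(id: String): ZoneId = ZoneId.of(id, ZoneId.SHORT_IDS)

  def main(args: Array[String]): Unit =
    Seq("Z", "-08:30", "CTT", "UTC+01:30", "Europe/Paris").foreach { id =>
      println(s"$id -> ${parseZone(id)}") // e.g. CTT -> Asia/Shanghai
    }
}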
@@ -424,9 +430,9 @@ trait SparkDateTimeUtils {
val maxDigitsYear = 6
// For the nanosecond part, more than 6 digits is allowed, but will be truncated.
segment == 6 || (segment == 0 && digits >= 4 && digits <= maxDigitsYear) ||
- // For the zoneId segment(7), it's could be zero digits when it's a region-based zone ID
- (segment == 7 && digits <= 2) ||
- (segment != 0 && segment != 6 && segment != 7 && digits > 0 && digits <= 2)
+ // For the zoneId segment(7), it's could be zero digits when it's a region-based zone ID
+ (segment == 7 && digits <= 2) ||
+ (segment != 0 && segment != 6 && segment != 7 && digits > 0 && digits <= 2)
}
if (s == null) {
return (Array.empty, None, false)
@@ -523,7 +529,7 @@ trait SparkDateTimeUtils {
tz = Some(new String(bytes, j, strEndTrimmed - j))
j = strEndTrimmed - 1
}
- if (i == 6 && b != '.') {
+ if (i == 6 && b != '.') {
i += 1
}
} else {
@@ -612,11 +618,11 @@ trait SparkDateTimeUtils {
*
* If the input string contains a component associated with time zone, the method will return
* `None` if `allowTimeZone` is set to `false`. If `allowTimeZone` is set to `true`, the method
- * will simply discard the time zone component. Enable the check to detect situations like parsing
- * a timestamp with time zone as TimestampNTZType.
+ * will simply discard the time zone component. Enable the check to detect situations like
+ * parsing a timestamp with time zone as TimestampNTZType.
*
- * The return type is [[Option]] in order to distinguish between 0L and null. Please
- * refer to `parseTimestampString` for the allowed formats.
+ * The return type is [[Option]] in order to distinguish between 0L and null. Please refer to
+ * `parseTimestampString` for the allowed formats.
*/
def stringToTimestampWithoutTimeZone(s: UTF8String, allowTimeZone: Boolean): Option[Long] = {
try {
@@ -637,10 +643,13 @@ trait SparkDateTimeUtils {
}
/**
- * Returns the index of the first non-whitespace and non-ISO control character in the byte array.
+ * Returns the index of the first non-whitespace and non-ISO control character in the byte
+ * array.
*
- * @param bytes The byte array to be processed.
- * @return The start index after trimming.
+ * @param bytes
+ * The byte array to be processed.
+ * @return
+ * The start index after trimming.
*/
@inline private def getTrimmedStart(bytes: Array[Byte]) = {
var start = 0
@@ -655,9 +664,12 @@ trait SparkDateTimeUtils {
/**
* Returns the index of the last non-whitespace and non-ISO control character in the byte array.
*
- * @param start The starting index for the search.
- * @param bytes The byte array to be processed.
- * @return The end index after trimming.
+ * @param start
+ * The starting index for the search.
+ * @param bytes
+ * The byte array to be processed.
+ * @return
+ * The end index after trimming.
*/
@inline private def getTrimmedEnd(start: Int, bytes: Array[Byte]) = {
var end = bytes.length - 1
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkIntervalUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkIntervalUtils.scala
index 5e236187a4a0f..b8387b78ae3e2 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkIntervalUtils.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkIntervalUtils.scala
@@ -38,17 +38,16 @@ trait SparkIntervalUtils {
private final val minDurationSeconds = Math.floorDiv(Long.MinValue, MICROS_PER_SECOND)
/**
- * Converts this duration to the total length in microseconds.
- *
- * If this duration is too large to fit in a [[Long]] microseconds, then an
- * exception is thrown.
- *
- * If this duration has greater than microsecond precision, then the conversion
- * will drop any excess precision information as though the amount in nanoseconds
- * was subject to integer division by one thousand.
+ * Converts this duration to the total length in microseconds. If this duration is too large
+ * to fit in a [[Long]] microseconds, then an exception is thrown. If this duration has
+ * greater than microsecond precision, then the conversion will drop any excess precision
+ * information as though the amount in nanoseconds was subject to integer division by one
+ * thousand.
*
- * @return The total length of the duration in microseconds
- * @throws ArithmeticException If numeric overflow occurs
+ * @return
+ * The total length of the duration in microseconds
+ * @throws ArithmeticException
+ * If numeric overflow occurs
*/
def durationToMicros(duration: Duration): Long = {
durationToMicros(duration, DT.SECOND)
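// Minimal version of the conversion documented above: overflow surfaces as ArithmeticException
// and precision beyond microseconds is discarded. The special handling of the smallest
// representable duration (seen in the hunk below) is intentionally omitted from this sketch.
import java.time.Duration

object DurationToMicrosSketch {
  def durationToMicros(d: Duration): Long =
    Math.addExact(Math.multiplyExact(d.getSeconds, 1000000L), d.getNano / 1000L)
}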
@@ -59,7 +58,8 @@ trait SparkIntervalUtils {
val micros = if (seconds == minDurationSeconds) {
val microsInSeconds = (minDurationSeconds + 1) * MICROS_PER_SECOND
val nanoAdjustment = duration.getNano
- assert(0 <= nanoAdjustment && nanoAdjustment < NANOS_PER_SECOND,
+ assert(
+ 0 <= nanoAdjustment && nanoAdjustment < NANOS_PER_SECOND,
"Duration.getNano() must return the adjustment to the seconds field " +
"in the range from 0 to 999999999 nanoseconds, inclusive.")
Math.addExact(microsInSeconds, (nanoAdjustment - NANOS_PER_SECOND) / NANOS_PER_MICROS)
@@ -77,14 +77,13 @@ trait SparkIntervalUtils {
}
/**
- * Gets the total number of months in this period.
- *
- * This returns the total number of months in the period by multiplying the
- * number of years by 12 and adding the number of months.
- *
+ * Gets the total number of months in this period. This returns the total number of months
+ * in the period by multiplying the number of years by 12 and adding the number of months.
*
- * @return The total number of months in the period, may be negative
- * @throws ArithmeticException If numeric overflow occurs
+ * @return
+ * The total number of months in the period, may be negative
+ * @throws ArithmeticException
+ * If numeric overflow occurs
*/
def periodToMonths(period: Period): Int = {
periodToMonths(period, YM.MONTH)
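// The "years * 12 + months" rule above is also what java.time.Period#toTotalMonths computes
// (the days unit is ignored); toIntExact keeps the documented overflow behaviour.
import java.time.Period

object PeriodToMonthsSketch {
  def periodToMonths(p: Period): Int = Math.toIntExact(p.toTotalMonths)
}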
@@ -103,39 +102,41 @@ trait SparkIntervalUtils {
/**
* Obtains a [[Duration]] representing a number of microseconds.
*
- * @param micros The number of microseconds, positive or negative
- * @return A [[Duration]], not null
+ * @param micros
+ * The number of microseconds, positive or negative
+ * @return
+ * A [[Duration]], not null
*/
def microsToDuration(micros: Long): Duration = Duration.of(micros, ChronoUnit.MICROS)
/**
- * Obtains a [[Period]] representing a number of months. The days unit will be zero, and the years
- * and months units will be normalized.
+ * Obtains a [[Period]] representing a number of months. The days unit will be zero, and the
+ * years and months units will be normalized.
*
- *
- * The months unit is adjusted to have an absolute value < 12, with the years unit being adjusted
- * to compensate. For example, the method returns "2 years and 3 months" for the 27 input months.
- *
- * The sign of the years and months units will be the same after normalization.
- * For example, -13 months will be converted to "-1 year and -1 month".
+ * The months unit is adjusted to have an absolute value < 12, with the years unit being
+ * adjusted to compensate. For example, the method returns "2 years and 3 months" for the 27
+ * input months. The sign of the years and months units will be the same after
+ * normalization. For example, -13 months will be converted to "-1 year and -1 month".
*
- * @param months The number of months, positive or negative
- * @return The period of months, not null
+ * @param months
+ * The number of months, positive or negative
+ * @return
+ * The period of months, not null
*/
def monthsToPeriod(months: Int): Period = Period.ofMonths(months).normalized()
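// Quick standalone check of the normalization described above:
import java.time.Period

object MonthsToPeriodSketch {
  def main(args: Array[String]): Unit = {
    println(Period.ofMonths(27).normalized())  // P2Y3M   ("2 years and 3 months")
    println(Period.ofMonths(-13).normalized()) // P-1Y-1M ("-1 year and -1 month")
  }
}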
/**
* Converts a string to [[CalendarInterval]] case-insensitively.
*
- * @throws IllegalArgumentException if the input string is not in valid interval format.
+ * @throws IllegalArgumentException
+ * if the input string is not in valid interval format.
*/
def stringToInterval(input: UTF8String): CalendarInterval = {
import ParseState._
if (input == null) {
throw new SparkIllegalArgumentException(
errorClass = "INVALID_INTERVAL_FORMAT.INPUT_IS_NULL",
- messageParameters = Map(
- "input" -> "null"))
+ messageParameters = Map("input" -> "null"))
}
// scalastyle:off caselocale .toLowerCase
val s = input.trimAll().toLowerCase
@@ -144,8 +145,7 @@ trait SparkIntervalUtils {
if (bytes.isEmpty) {
throw new SparkIllegalArgumentException(
errorClass = "INVALID_INTERVAL_FORMAT.INPUT_IS_EMPTY",
- messageParameters = Map(
- "input" -> input.toString))
+ messageParameters = Map("input" -> input.toString))
}
var state = PREFIX
var i = 0
@@ -182,14 +182,11 @@ trait SparkIntervalUtils {
if (s.numBytes() == intervalStr.numBytes()) {
throw new SparkIllegalArgumentException(
errorClass = "INVALID_INTERVAL_FORMAT.INPUT_IS_EMPTY",
- messageParameters = Map(
- "input" -> input.toString))
+ messageParameters = Map("input" -> input.toString))
} else if (!Character.isWhitespace(bytes(i + intervalStr.numBytes()))) {
throw new SparkIllegalArgumentException(
errorClass = "INVALID_INTERVAL_FORMAT.INVALID_PREFIX",
- messageParameters = Map(
- "input" -> input.toString,
- "prefix" -> currentWord))
+ messageParameters = Map("input" -> input.toString, "prefix" -> currentWord))
} else {
i += intervalStr.numBytes() + 1
}
@@ -224,11 +221,10 @@ trait SparkIntervalUtils {
pointPrefixed = true
i += 1
state = VALUE_FRACTIONAL_PART
- case _ => throw new SparkIllegalArgumentException(
- errorClass = "INVALID_INTERVAL_FORMAT.UNRECOGNIZED_NUMBER",
- messageParameters = Map(
- "input" -> input.toString,
- "number" -> currentWord))
+ case _ =>
+ throw new SparkIllegalArgumentException(
+ errorClass = "INVALID_INTERVAL_FORMAT.UNRECOGNIZED_NUMBER",
+ messageParameters = Map("input" -> input.toString, "number" -> currentWord))
}
case TRIM_BEFORE_VALUE => trimToNextState(b, VALUE)
case VALUE =>
@@ -237,20 +233,19 @@ trait SparkIntervalUtils {
try {
currentValue = Math.addExact(Math.multiplyExact(10, currentValue), (b - '0'))
} catch {
- case e: ArithmeticException => throw new SparkIllegalArgumentException(
- errorClass = "INVALID_INTERVAL_FORMAT.ARITHMETIC_EXCEPTION",
- messageParameters = Map(
- "input" -> input.toString))
+ case e: ArithmeticException =>
+ throw new SparkIllegalArgumentException(
+ errorClass = "INVALID_INTERVAL_FORMAT.ARITHMETIC_EXCEPTION",
+ messageParameters = Map("input" -> input.toString))
}
case _ if Character.isWhitespace(b) => state = TRIM_BEFORE_UNIT
case '.' =>
fractionScale = initialFractionScale
state = VALUE_FRACTIONAL_PART
- case _ => throw new SparkIllegalArgumentException(
- errorClass = "INVALID_INTERVAL_FORMAT.INVALID_VALUE",
- messageParameters = Map(
- "input" -> input.toString,
- "value" -> currentWord))
+ case _ =>
+ throw new SparkIllegalArgumentException(
+ errorClass = "INVALID_INTERVAL_FORMAT.INVALID_VALUE",
+ messageParameters = Map("input" -> input.toString, "value" -> currentWord))
}
i += 1
case VALUE_FRACTIONAL_PART =>
@@ -264,15 +259,11 @@ trait SparkIntervalUtils {
} else if ('0' <= b && b <= '9') {
throw new SparkIllegalArgumentException(
errorClass = "INVALID_INTERVAL_FORMAT.INVALID_PRECISION",
- messageParameters = Map(
- "input" -> input.toString,
- "value" -> currentWord))
+ messageParameters = Map("input" -> input.toString, "value" -> currentWord))
} else {
throw new SparkIllegalArgumentException(
errorClass = "INVALID_INTERVAL_FORMAT.INVALID_VALUE",
- messageParameters = Map(
- "input" -> input.toString,
- "value" -> currentWord))
+ messageParameters = Map("input" -> input.toString, "value" -> currentWord))
}
i += 1
case TRIM_BEFORE_UNIT => trimToNextState(b, UNIT_BEGIN)
@@ -281,9 +272,7 @@ trait SparkIntervalUtils {
if (b != 's' && fractionScale >= 0) {
throw new SparkIllegalArgumentException(
errorClass = "INVALID_INTERVAL_FORMAT.INVALID_FRACTION",
- messageParameters = Map(
- "input" -> input.toString,
- "unit" -> currentWord))
+ messageParameters = Map("input" -> input.toString, "unit" -> currentWord))
}
if (isNegative) {
currentValue = -currentValue
@@ -328,44 +317,38 @@ trait SparkIntervalUtils {
} else {
throw new SparkIllegalArgumentException(
errorClass = "INVALID_INTERVAL_FORMAT.INVALID_UNIT",
- messageParameters = Map(
- "input" -> input.toString,
- "unit" -> currentWord))
+ messageParameters = Map("input" -> input.toString, "unit" -> currentWord))
}
- case _ => throw new SparkIllegalArgumentException(
- errorClass = "INVALID_INTERVAL_FORMAT.INVALID_UNIT",
- messageParameters = Map(
- "input" -> input.toString,
- "unit" -> currentWord))
+ case _ =>
+ throw new SparkIllegalArgumentException(
+ errorClass = "INVALID_INTERVAL_FORMAT.INVALID_UNIT",
+ messageParameters = Map("input" -> input.toString, "unit" -> currentWord))
}
} catch {
- case e: ArithmeticException => throw new SparkIllegalArgumentException(
- errorClass = "INVALID_INTERVAL_FORMAT.ARITHMETIC_EXCEPTION",
- messageParameters = Map(
- "input" -> input.toString))
+ case e: ArithmeticException =>
+ throw new SparkIllegalArgumentException(
+ errorClass = "INVALID_INTERVAL_FORMAT.ARITHMETIC_EXCEPTION",
+ messageParameters = Map("input" -> input.toString))
}
state = UNIT_SUFFIX
case UNIT_SUFFIX =>
b match {
case 's' => state = UNIT_END
case _ if Character.isWhitespace(b) => state = TRIM_BEFORE_SIGN
- case _ => throw new SparkIllegalArgumentException(
- errorClass = "INVALID_INTERVAL_FORMAT.INVALID_UNIT",
- messageParameters = Map(
- "input" -> input.toString,
- "unit" -> currentWord))
+ case _ =>
+ throw new SparkIllegalArgumentException(
+ errorClass = "INVALID_INTERVAL_FORMAT.INVALID_UNIT",
+ messageParameters = Map("input" -> input.toString, "unit" -> currentWord))
}
i += 1
case UNIT_END =>
- if (Character.isWhitespace(b) ) {
+ if (Character.isWhitespace(b)) {
i += 1
state = TRIM_BEFORE_SIGN
} else {
throw new SparkIllegalArgumentException(
errorClass = "INVALID_INTERVAL_FORMAT.INVALID_UNIT",
- messageParameters = Map(
- "input" -> input.toString,
- "unit" -> currentWord))
+ messageParameters = Map("input" -> input.toString, "unit" -> currentWord))
}
}
}
@@ -373,36 +356,37 @@ trait SparkIntervalUtils {
val result = state match {
case UNIT_SUFFIX | UNIT_END | TRIM_BEFORE_SIGN =>
new CalendarInterval(months, days, microseconds)
- case TRIM_BEFORE_VALUE => throw new SparkIllegalArgumentException(
- errorClass = "INVALID_INTERVAL_FORMAT.MISSING_NUMBER",
- messageParameters = Map(
- "input" -> input.toString,
- "word" -> currentWord))
+ case TRIM_BEFORE_VALUE =>
+ throw new SparkIllegalArgumentException(
+ errorClass = "INVALID_INTERVAL_FORMAT.MISSING_NUMBER",
+ messageParameters = Map("input" -> input.toString, "word" -> currentWord))
case VALUE | VALUE_FRACTIONAL_PART =>
throw new SparkIllegalArgumentException(
errorClass = "INVALID_INTERVAL_FORMAT.MISSING_UNIT",
- messageParameters = Map(
- "input" -> input.toString,
- "word" -> currentWord))
- case _ => throw new SparkIllegalArgumentException(
- errorClass = "INVALID_INTERVAL_FORMAT.UNKNOWN_PARSING_ERROR",
- messageParameters = Map(
- "input" -> input.toString,
- "word" -> currentWord))
+ messageParameters = Map("input" -> input.toString, "word" -> currentWord))
+ case _ =>
+ throw new SparkIllegalArgumentException(
+ errorClass = "INVALID_INTERVAL_FORMAT.UNKNOWN_PARSING_ERROR",
+ messageParameters = Map("input" -> input.toString, "word" -> currentWord))
}
result
}
/**
- * Converts an year-month interval as a number of months to its textual representation
- * which conforms to the ANSI SQL standard.
+ * Converts an year-month interval as a number of months to its textual representation which
+ * conforms to the ANSI SQL standard.
*
- * @param months The number of months, positive or negative
- * @param style The style of textual representation of the interval
- * @param startField The start field (YEAR or MONTH) which the interval comprises of.
- * @param endField The end field (YEAR or MONTH) which the interval comprises of.
- * @return Year-month interval string
+ * @param months
+ * The number of months, positive or negative
+ * @param style
+ * The style of textual representation of the interval
+ * @param startField
+ * The start field (YEAR or MONTH) which the interval comprises of.
+ * @param endField
+ * The end field (YEAR or MONTH) which the interval comprises of.
+ * @return
+ * Year-month interval string
*/
def toYearMonthIntervalString(
months: Int,
@@ -434,14 +418,19 @@ trait SparkIntervalUtils {
}
/**
- * Converts a day-time interval as a number of microseconds to its textual representation
- * which conforms to the ANSI SQL standard.
+ * Converts a day-time interval as a number of microseconds to its textual representation which
+ * conforms to the ANSI SQL standard.
*
- * @param micros The number of microseconds, positive or negative
- * @param style The style of textual representation of the interval
- * @param startField The start field (DAY, HOUR, MINUTE, SECOND) which the interval comprises of.
- * @param endField The end field (DAY, HOUR, MINUTE, SECOND) which the interval comprises of.
- * @return Day-time interval string
+ * @param micros
+ * The number of microseconds, positive or negative
+ * @param style
+ * The style of textual representation of the interval
+ * @param startField
+ * The start field (DAY, HOUR, MINUTE, SECOND) which the interval comprises of.
+ * @param endField
+ * The end field (DAY, HOUR, MINUTE, SECOND) which the interval comprises of.
+ * @return
+ * Day-time interval string
*/
def toDayTimeIntervalString(
micros: Long,
@@ -514,8 +503,9 @@ trait SparkIntervalUtils {
rest %= MICROS_PER_MINUTE
case DT.SECOND =>
val leadZero = if (rest < 10 * MICROS_PER_SECOND) "0" else ""
- formatBuilder.append(s"$leadZero" +
- s"${java.math.BigDecimal.valueOf(rest, 6).stripTrailingZeros.toPlainString}")
+ formatBuilder.append(
+ s"$leadZero" +
+ s"${java.math.BigDecimal.valueOf(rest, 6).stripTrailingZeros.toPlainString}")
}
if (startField < DT.HOUR && DT.HOUR <= endField) {
@@ -565,20 +555,11 @@ trait SparkIntervalUtils {
protected val microsStr: UTF8String = unitToUtf8("microsecond")
protected val nanosStr: UTF8String = unitToUtf8("nanosecond")
-
private object ParseState extends Enumeration {
type ParseState = Value
- val PREFIX,
- TRIM_BEFORE_SIGN,
- SIGN,
- TRIM_BEFORE_VALUE,
- VALUE,
- VALUE_FRACTIONAL_PART,
- TRIM_BEFORE_UNIT,
- UNIT_BEGIN,
- UNIT_SUFFIX,
- UNIT_END = Value
+ val PREFIX, TRIM_BEFORE_SIGN, SIGN, TRIM_BEFORE_VALUE, VALUE, VALUE_FRACTIONAL_PART,
+ TRIM_BEFORE_UNIT, UNIT_BEGIN, UNIT_SUFFIX, UNIT_END = Value
}
}
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkParserUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkParserUtils.scala
index 7597cb1d9087d..01ee899085701 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkParserUtils.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkParserUtils.scala
@@ -62,8 +62,8 @@ trait SparkParserUtils {
val secondChar = s.charAt(start + 1)
val thirdChar = s.charAt(start + 2)
(firstChar == '0' || firstChar == '1') &&
- (secondChar >= '0' && secondChar <= '7') &&
- (thirdChar >= '0' && thirdChar <= '7')
+ (secondChar >= '0' && secondChar <= '7') &&
+ (thirdChar >= '0' && thirdChar <= '7')
}
val isRawString = {
@@ -97,15 +97,18 @@ trait SparkParserUtils {
// \u0000 style 16-bit unicode character literals.
sb.append(Integer.parseInt(b, i + 1, i + 1 + 4, 16).toChar)
i += 1 + 4
- } else if (cAfterBackslash == 'U' && i + 1 + 8 <= length && allCharsAreHex(b, i + 1, 8)) {
+ } else if (cAfterBackslash == 'U' && i + 1 + 8 <= length && allCharsAreHex(
+ b,
+ i + 1,
+ 8)) {
// \U00000000 style 32-bit unicode character literals.
// Use Long to treat codePoint as unsigned in the range of 32-bit.
val codePoint = JLong.parseLong(b, i + 1, i + 1 + 8, 16)
if (codePoint < 0x10000) {
- sb.append((codePoint & 0xFFFF).toChar)
+ sb.append((codePoint & 0xffff).toChar)
} else {
- val highSurrogate = (codePoint - 0x10000) / 0x400 + 0xD800
- val lowSurrogate = (codePoint - 0x10000) % 0x400 + 0xDC00
+ val highSurrogate = (codePoint - 0x10000) / 0x400 + 0xd800
+ val lowSurrogate = (codePoint - 0x10000) % 0x400 + 0xdc00
sb.append(highSurrogate.toChar)
sb.append(lowSurrogate.toChar)
}
@@ -147,8 +150,13 @@ trait SparkParserUtils {
if (text.isEmpty) {
CurrentOrigin.set(position(ctx.getStart))
} else {
- CurrentOrigin.set(positionAndText(ctx.getStart, ctx.getStop, text.get,
- current.objectType, current.objectName))
+ CurrentOrigin.set(
+ positionAndText(
+ ctx.getStart,
+ ctx.getStop,
+ text.get,
+ current.objectType,
+ current.objectName))
}
try {
f
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala
index edb1ee371b156..0608322be13b3 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala
@@ -24,9 +24,8 @@ import org.apache.spark.unsafe.array.ByteArrayUtils
import org.apache.spark.util.ArrayImplicits._
/**
- * Concatenation of sequence of strings to final string with cheap append method
- * and one memory allocation for the final string. Can also bound the final size of
- * the string.
+ * Concatenation of sequence of strings to final string with cheap append method and one memory
+ * allocation for the final string. Can also bound the final size of the string.
*/
class StringConcat(val maxLength: Int = ByteArrayUtils.MAX_ROUNDED_ARRAY_LENGTH) {
protected val strings = new java.util.ArrayList[String]
@@ -35,9 +34,9 @@ class StringConcat(val maxLength: Int = ByteArrayUtils.MAX_ROUNDED_ARRAY_LENGTH)
def atLimit: Boolean = length >= maxLength
/**
- * Appends a string and accumulates its length to allocate a string buffer for all
- * appended strings once in the toString method. Returns true if the string still
- * has room for further appends before it hits its max limit.
+ * Appends a string and accumulates its length to allocate a string buffer for all appended
+ * strings once in the toString method. Returns true if the string still has room for further
+ * appends before it hits its max limit.
*/
def append(s: String): Unit = {
if (s != null) {
@@ -56,8 +55,8 @@ class StringConcat(val maxLength: Int = ByteArrayUtils.MAX_ROUNDED_ARRAY_LENGTH)
}
/**
- * The method allocates memory for all appended strings, writes them to the memory and
- * returns concatenated string.
+ * The method allocates memory for all appended strings, writes them to the memory and returns
+ * concatenated string.
*/
override def toString: String = {
val finalLength = if (atLimit) maxLength else length
@@ -68,6 +67,7 @@ class StringConcat(val maxLength: Int = ByteArrayUtils.MAX_ROUNDED_ARRAY_LENGTH)
}
object SparkStringUtils extends Logging {
+
/** Whether we have warned about plan string truncation yet. */
private val truncationWarningPrinted = new AtomicBoolean(false)
@@ -75,7 +75,8 @@ object SparkStringUtils extends Logging {
* Format a sequence with semantics similar to calling .mkString(). Any elements beyond
* maxNumToStringFields will be dropped and replaced by a "... N more fields" placeholder.
*
- * @return the trimmed and formatted string.
+ * @return
+ * the trimmed and formatted string.
*/
def truncatedString[T](
seq: Seq[T],
@@ -90,8 +91,9 @@ object SparkStringUtils extends Logging {
s"behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.")
}
val numFields = math.max(0, maxFields - 1)
- seq.take(numFields).mkString(
- start, sep, sep + "... " + (seq.length - numFields) + " more fields" + end)
+ seq
+ .take(numFields)
+ .mkString(start, sep, sep + "... " + (seq.length - numFields) + " more fields" + end)
} else {
seq.mkString(start, sep, end)
}
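// A stripped-down re-creation of the truncation documented above (without the one-time
// warning), so the "... N more fields" placeholder is easy to see in isolation.
object TruncatedStringSketch {
  def truncatedString[T](seq: Seq[T], start: String, sep: String, end: String,
      maxFields: Int): String = {
    if (seq.length > maxFields) {
      val numFields = math.max(0, maxFields - 1)
      seq.take(numFields)
        .mkString(start, sep, sep + "... " + (seq.length - numFields) + " more fields" + end)
    } else {
      seq.mkString(start, sep, end)
    }
  }

  def main(args: Array[String]): Unit =
    // prints: [a, b, ... 2 more fields]
    println(truncatedString(Seq("a", "b", "c", "d"), "[", ", ", "]", maxFields = 3))
}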
@@ -106,8 +108,8 @@ object SparkStringUtils extends Logging {
HexFormat.of().withDelimiter(" ").withUpperCase()
/**
- * Returns a pretty string of the byte array which prints each byte as a hex digit and add spaces
- * between them. For example, [1A C0].
+ * Returns a pretty string of the byte array which prints each byte as a hex digit and add
+ * spaces between them. For example, [1A C0].
*/
def getHexString(bytes: Array[Byte]): String = {
s"[${SPACE_DELIMITED_UPPERCASE_HEX.formatHex(bytes)}]"
@@ -122,8 +124,8 @@ object SparkStringUtils extends Logging {
val leftPadded = left ++ Seq.fill(math.max(right.size - left.size, 0))("")
val rightPadded = right ++ Seq.fill(math.max(left.size - right.size, 0))("")
- leftPadded.zip(rightPadded).map {
- case (l, r) => (if (l == r) " " else "!") + l + (" " * ((maxLeftSize - l.length) + 3)) + r
+ leftPadded.zip(rightPadded).map { case (l, r) =>
+ (if (l == r) " " else "!") + l + (" " * ((maxLeftSize - l.length) + 3)) + r
}
}
}
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala
index 79d627b493fd8..4fcb84daf187d 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala
@@ -41,14 +41,20 @@ import org.apache.spark.sql.types.{Decimal, TimestampNTZType}
import org.apache.spark.unsafe.types.UTF8String
sealed trait TimestampFormatter extends Serializable {
+
/**
* Parses a timestamp in a string and converts it to microseconds.
*
- * @param s - string with timestamp to parse
- * @return microseconds since epoch.
- * @throws ParseException can be thrown by legacy parser
- * @throws DateTimeParseException can be thrown by new parser
- * @throws DateTimeException unable to obtain local date or time
+ * @param s
+ * \- string with timestamp to parse
+ * @return
+ * microseconds since epoch.
+ * @throws ParseException
+ * can be thrown by legacy parser
+ * @throws DateTimeParseException
+ * can be thrown by new parser
+ * @throws DateTimeException
+ * unable to obtain local date or time
*/
@throws(classOf[ParseException])
@throws(classOf[DateTimeParseException])
@@ -58,11 +64,16 @@ sealed trait TimestampFormatter extends Serializable {
/**
* Parses a timestamp in a string and converts it to an optional number of microseconds.
*
- * @param s - string with timestamp to parse
- * @return An optional number of microseconds since epoch. The result is None on invalid input.
- * @throws ParseException can be thrown by legacy parser
- * @throws DateTimeParseException can be thrown by new parser
- * @throws DateTimeException unable to obtain local date or time
+ * @param s
+ * \- string with timestamp to parse
+ * @return
+ * An optional number of microseconds since epoch. The result is None on invalid input.
+ * @throws ParseException
+ * can be thrown by legacy parser
+ * @throws DateTimeParseException
+ * can be thrown by new parser
+ * @throws DateTimeException
+ * unable to obtain local date or time
*/
@throws(classOf[ParseException])
@throws(classOf[DateTimeParseException])
@@ -75,16 +86,24 @@ sealed trait TimestampFormatter extends Serializable {
}
/**
- * Parses a timestamp in a string and converts it to microseconds since Unix Epoch in local time.
+ * Parses a timestamp in a string and converts it to microseconds since Unix Epoch in local
+ * time.
*
- * @param s - string with timestamp to parse
- * @param allowTimeZone - indicates strict parsing of timezone
- * @return microseconds since epoch.
- * @throws ParseException can be thrown by legacy parser
- * @throws DateTimeParseException can be thrown by new parser
- * @throws DateTimeException unable to obtain local date or time
- * @throws IllegalStateException The formatter for timestamp without time zone should always
- * implement this method. The exception should never be hit.
+ * @param s
+ * \- string with timestamp to parse
+ * @param allowTimeZone
+ * \- indicates strict parsing of timezone
+ * @return
+ * microseconds since epoch.
+ * @throws ParseException
+ * can be thrown by legacy parser
+ * @throws DateTimeParseException
+ * can be thrown by new parser
+ * @throws DateTimeException
+ * unable to obtain local date or time
+ * @throws IllegalStateException
+ * The formatter for timestamp without time zone should always implement this method. The
+ * exception should never be hit.
*/
@throws(classOf[ParseException])
@throws(classOf[DateTimeParseException])
@@ -99,14 +118,21 @@ sealed trait TimestampFormatter extends Serializable {
* Parses a timestamp in a string and converts it to an optional number of microseconds since
* Unix Epoch in local time.
*
- * @param s - string with timestamp to parse
- * @param allowTimeZone - indicates strict parsing of timezone
- * @return An optional number of microseconds since epoch. The result is None on invalid input.
- * @throws ParseException can be thrown by legacy parser
- * @throws DateTimeParseException can be thrown by new parser
- * @throws DateTimeException unable to obtain local date or time
- * @throws IllegalStateException The formatter for timestamp without time zone should always
- * implement this method. The exception should never be hit.
+ * @param s
+ * \- string with timestamp to parse
+ * @param allowTimeZone
+ * \- indicates strict parsing of timezone
+ * @return
+ * An optional number of microseconds since epoch. The result is None on invalid input.
+ * @throws ParseException
+ * can be thrown by legacy parser
+ * @throws DateTimeParseException
+ * can be thrown by new parser
+ * @throws DateTimeException
+ * unable to obtain local date or time
+ * @throws IllegalStateException
+ * The formatter for timestamp without time zone should always implement this method. The
+ * exception should never be hit.
*/
@throws(classOf[ParseException])
@throws(classOf[DateTimeParseException])
@@ -120,8 +146,8 @@ sealed trait TimestampFormatter extends Serializable {
}
/**
- * Parses a timestamp in a string and converts it to microseconds since Unix Epoch in local time.
- * Zone-id and zone-offset components are ignored.
+ * Parses a timestamp in a string and converts it to microseconds since Unix Epoch in local
+ * time. Zone-id and zone-offset components are ignored.
*/
@throws(classOf[ParseException])
@throws(classOf[DateTimeParseException])
@@ -144,9 +170,10 @@ sealed trait TimestampFormatter extends Serializable {
/**
* Validates the pattern string.
- * @param checkLegacy if true and the pattern is invalid, check whether the pattern is valid for
- * legacy formatters and show hints for using legacy formatter.
- * Otherwise, simply check the pattern string.
+ * @param checkLegacy
+ * if true and the pattern is invalid, check whether the pattern is valid for legacy
+ * formatters and show hints for using legacy formatter. Otherwise, simply check the pattern
+ * string.
*/
def validatePatternString(checkLegacy: Boolean): Unit
}
@@ -157,7 +184,8 @@ class Iso8601TimestampFormatter(
locale: Locale,
legacyFormat: LegacyDateFormat = LENIENT_SIMPLE_DATE_FORMAT,
isParsing: Boolean)
- extends TimestampFormatter with DateTimeFormatterHelper {
+ extends TimestampFormatter
+ with DateTimeFormatterHelper {
@transient
protected lazy val formatter: DateTimeFormatter =
getOrCreateFormatter(pattern, locale, isParsing)
@@ -166,8 +194,8 @@ class Iso8601TimestampFormatter(
private lazy val zonedFormatter: DateTimeFormatter = formatter.withZone(zoneId)
@transient
- protected lazy val legacyFormatter = TimestampFormatter.getLegacyFormatter(
- pattern, zoneId, locale, legacyFormat)
+ protected lazy val legacyFormatter =
+ TimestampFormatter.getLegacyFormatter(pattern, zoneId, locale, legacyFormat)
override def parseOptional(s: String): Option[Long] = {
try {
@@ -235,8 +263,8 @@ class Iso8601TimestampFormatter(
override def format(instant: Instant): String = {
try {
zonedFormatter.format(instant)
- } catch checkFormattedDiff(toJavaTimestamp(instantToMicros(instant)),
- (t: Timestamp) => format(t))
+ } catch
+ checkFormattedDiff(toJavaTimestamp(instantToMicros(instant)), (t: Timestamp) => format(t))
}
override def format(us: Long): String = {
@@ -256,8 +284,8 @@ class Iso8601TimestampFormatter(
if (checkLegacy) {
try {
formatter
- } catch checkLegacyFormatter(pattern,
- legacyFormatter.validatePatternString(checkLegacy = true))
+ } catch
+ checkLegacyFormatter(pattern, legacyFormatter.validatePatternString(checkLegacy = true))
()
} else {
try {
@@ -268,22 +296,30 @@ class Iso8601TimestampFormatter(
}
/**
- * The formatter for timestamps which doesn't require users to specify a pattern. While formatting,
- * it uses the default pattern [[TimestampFormatter.defaultPattern()]]. In parsing, it follows
- * the CAST logic in conversion of strings to Catalyst's TimestampType.
+ * The formatter for timestamps which doesn't require users to specify a pattern. While
+ * formatting, it uses the default pattern [[TimestampFormatter.defaultPattern()]]. In parsing, it
+ * follows the CAST logic in conversion of strings to Catalyst's TimestampType.
*
- * @param zoneId The time zone ID in which timestamps should be formatted or parsed.
- * @param locale The locale overrides the system locale and is used in formatting.
- * @param legacyFormat Defines the formatter used for legacy timestamps.
- * @param isParsing Whether the formatter is used for parsing (`true`) or for formatting (`false`).
+ * @param zoneId
+ * The time zone ID in which timestamps should be formatted or parsed.
+ * @param locale
+ * The locale overrides the system locale and is used in formatting.
+ * @param legacyFormat
+ * Defines the formatter used for legacy timestamps.
+ * @param isParsing
+ * Whether the formatter is used for parsing (`true`) or for formatting (`false`).
*/
class DefaultTimestampFormatter(
zoneId: ZoneId,
locale: Locale,
legacyFormat: LegacyDateFormat = LENIENT_SIMPLE_DATE_FORMAT,
isParsing: Boolean)
- extends Iso8601TimestampFormatter(
- TimestampFormatter.defaultPattern(), zoneId, locale, legacyFormat, isParsing) {
+ extends Iso8601TimestampFormatter(
+ TimestampFormatter.defaultPattern(),
+ zoneId,
+ locale,
+ legacyFormat,
+ isParsing) {
override def parse(s: String): Long = {
try {
@@ -299,7 +335,9 @@ class DefaultTimestampFormatter(
val utf8Value = UTF8String.fromString(s)
SparkDateTimeUtils.stringToTimestampWithoutTimeZone(utf8Value, allowTimeZone).getOrElse {
throw ExecutionErrors.cannotParseStringAsDataTypeError(
- TimestampFormatter.defaultPattern(), s, TimestampNTZType)
+ TimestampFormatter.defaultPattern(),
+ s,
+ TimestampNTZType)
}
} catch checkParsedDiff(s, legacyFormatter.parse)
}
@@ -311,20 +349,21 @@ class DefaultTimestampFormatter(
}
/**
- * The formatter parses/formats timestamps according to the pattern `yyyy-MM-dd HH:mm:ss.[..fff..]`
- * where `[..fff..]` is a fraction of second up to microsecond resolution. The formatter does not
- * output trailing zeros in the fraction. For example, the timestamp `2019-03-05 15:00:01.123400` is
- * formatted as the string `2019-03-05 15:00:01.1234`.
+ * The formatter parses/formats timestamps according to the pattern `yyyy-MM-dd
+ * HH:mm:ss.[..fff..]` where `[..fff..]` is a fraction of second up to microsecond resolution. The
+ * formatter does not output trailing zeros in the fraction. For example, the timestamp
+ * `2019-03-05 15:00:01.123400` is formatted as the string `2019-03-05 15:00:01.1234`.
*
- * @param zoneId the time zone identifier in which the formatter parses or format timestamps
+ * @param zoneId
+ * the time zone identifier in which the formatter parses or format timestamps
*/
class FractionTimestampFormatter(zoneId: ZoneId)
- extends Iso8601TimestampFormatter(
- TimestampFormatter.defaultPattern(),
- zoneId,
- TimestampFormatter.defaultLocale,
- LegacyDateFormats.FAST_DATE_FORMAT,
- isParsing = false) {
+ extends Iso8601TimestampFormatter(
+ TimestampFormatter.defaultPattern(),
+ zoneId,
+ TimestampFormatter.defaultLocale,
+ LegacyDateFormats.FAST_DATE_FORMAT,
+ isParsing = false) {
@transient
override protected lazy val formatter = DateTimeFormatterHelper.fractionFormatter
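// A rough java.time analogue of the fraction handling described above: appendFraction with a
// minimum width of 0 omits trailing zeros, so a .123400 fraction prints as ".1234".
import java.time.LocalDateTime
import java.time.format.DateTimeFormatterBuilder
import java.time.temporal.ChronoField

object FractionFormatSketch {
  private val formatter = new DateTimeFormatterBuilder()
    .appendPattern("yyyy-MM-dd HH:mm:ss")
    .appendFraction(ChronoField.NANO_OF_SECOND, 0, 9, true)
    .toFormatter

  def main(args: Array[String]): Unit =
    // prints: 2019-03-05 15:00:01.1234
    println(LocalDateTime.of(2019, 3, 5, 15, 0, 1, 123400000).format(formatter))
}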
@@ -366,16 +405,14 @@ class FractionTimestampFormatter(zoneId: ZoneId)
}
/**
- * The custom sub-class of `GregorianCalendar` is needed to get access to
- * protected `fields` immediately after parsing. We cannot use
- * the `get()` method because it performs normalization of the fraction
- * part. Accordingly, the `MILLISECOND` field doesn't contain original value.
+ * The custom sub-class of `GregorianCalendar` is needed to get access to protected `fields`
+ * immediately after parsing. We cannot use the `get()` method because it performs normalization
+ * of the fraction part. Accordingly, the `MILLISECOND` field doesn't contain original value.
*
- * Also this class allows to set raw value to the `MILLISECOND` field
- * directly before formatting.
+ * Also this class allows to set raw value to the `MILLISECOND` field directly before formatting.
*/
class MicrosCalendar(tz: TimeZone, digitsInFraction: Int)
- extends GregorianCalendar(tz, Locale.US) {
+ extends GregorianCalendar(tz, Locale.US) {
// Converts parsed `MILLISECOND` field to seconds fraction in microsecond precision.
// For example if the fraction pattern is `SSSS` then `digitsInFraction` = 4, and
// if the `MILLISECOND` field was parsed to `1234`.
@@ -397,16 +434,13 @@ class MicrosCalendar(tz: TimeZone, digitsInFraction: Int)
}
}
-class LegacyFastTimestampFormatter(
- pattern: String,
- zoneId: ZoneId,
- locale: Locale) extends TimestampFormatter {
+class LegacyFastTimestampFormatter(pattern: String, zoneId: ZoneId, locale: Locale)
+ extends TimestampFormatter {
@transient private lazy val fastDateFormat =
FastDateFormat.getInstance(pattern, TimeZone.getTimeZone(zoneId), locale)
- @transient private lazy val cal = new MicrosCalendar(
- fastDateFormat.getTimeZone,
- fastDateFormat.getPattern.count(_ == 'S'))
+ @transient private lazy val cal =
+ new MicrosCalendar(fastDateFormat.getTimeZone, fastDateFormat.getPattern.count(_ == 'S'))
override def parse(s: String): Long = {
cal.clear() // Clear the calendar because it can be re-used many times
@@ -464,7 +498,8 @@ class LegacySimpleTimestampFormatter(
pattern: String,
zoneId: ZoneId,
locale: Locale,
- lenient: Boolean = true) extends TimestampFormatter {
+ lenient: Boolean = true)
+ extends TimestampFormatter {
@transient private lazy val sdf = {
val formatter = new SimpleDateFormat(pattern, locale)
formatter.setTimeZone(TimeZone.getTimeZone(zoneId))
@@ -586,14 +621,10 @@ object TimestampFormatter {
legacyFormat: LegacyDateFormat,
isParsing: Boolean,
forTimestampNTZ: Boolean): TimestampFormatter = {
- getFormatter(Some(format), zoneId, defaultLocale, legacyFormat, isParsing,
- forTimestampNTZ)
+ getFormatter(Some(format), zoneId, defaultLocale, legacyFormat, isParsing, forTimestampNTZ)
}
- def apply(
- format: String,
- zoneId: ZoneId,
- isParsing: Boolean): TimestampFormatter = {
+ def apply(format: String, zoneId: ZoneId, isParsing: Boolean): TimestampFormatter = {
getFormatter(Some(format), zoneId, isParsing = isParsing)
}
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/UDTUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/UDTUtils.scala
index a98aa26d02ef7..73ab43f04a5a0 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/UDTUtils.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/UDTUtils.scala
@@ -28,11 +28,12 @@ import org.apache.spark.util.SparkClassUtils
* with catalyst because they (amongst others) require access to Spark SQLs internal data
* representation.
*
- * This interface and its companion object provide an escape hatch for working with UDTs from within
- * the api project (e.g. Row.toJSON). The companion will try to bind to an implementation of the
- * interface in catalyst, if none is found it will bind to [[DefaultUDTUtils]].
+ * This interface and its companion object provide an escape hatch for working with UDTs from
+ * within the api project (e.g. Row.toJSON). The companion will try to bind to an implementation
+ * of the interface in catalyst, if none is found it will bind to [[DefaultUDTUtils]].
*/
private[sql] trait UDTUtils {
+
/**
* Convert the UDT instance to something that is compatible with [[org.apache.spark.sql.Row]].
* The returned value must conform to the schema of the UDT.
@@ -41,13 +42,14 @@ private[sql] trait UDTUtils {
}
private[sql] object UDTUtils extends UDTUtils {
- private val delegate = try {
- val cls = SparkClassUtils.classForName("org.apache.spark.sql.catalyst.util.UDTUtilsImpl")
- cls.getConstructor().newInstance().asInstanceOf[UDTUtils]
- } catch {
- case NonFatal(_) =>
- DefaultUDTUtils
- }
+ private val delegate =
+ try {
+ val cls = SparkClassUtils.classForName("org.apache.spark.sql.catalyst.util.UDTUtilsImpl")
+ cls.getConstructor().newInstance().asInstanceOf[UDTUtils]
+ } catch {
+ case NonFatal(_) =>
+ DefaultUDTUtils
+ }
override def toRow(value: Any, udt: UserDefinedType[Any]): Any = delegate.toRow(value, udt)
}
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/CompilationErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/CompilationErrors.scala
index 6034c41906313..3e63b8281f739 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/errors/CompilationErrors.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/CompilationErrors.scala
@@ -23,9 +23,7 @@ private[sql] trait CompilationErrors extends DataTypeErrorsBase {
def ambiguousColumnOrFieldError(name: Seq[String], numMatches: Int): AnalysisException = {
new AnalysisException(
errorClass = "AMBIGUOUS_COLUMN_OR_FIELD",
- messageParameters = Map(
- "name" -> toSQLId(name),
- "n" -> numMatches.toString))
+ messageParameters = Map("name" -> toSQLId(name), "n" -> numMatches.toString))
}
def columnNotFoundError(colName: String): AnalysisException = {
@@ -51,9 +49,7 @@ private[sql] trait CompilationErrors extends DataTypeErrorsBase {
}
def usingUntypedScalaUDFError(): Throwable = {
- new AnalysisException(
- errorClass = "UNTYPED_SCALA_UDF",
- messageParameters = Map.empty)
+ new AnalysisException(errorClass = "UNTYPED_SCALA_UDF", messageParameters = Map.empty)
}
def invalidBoundaryStartError(start: Long): Throwable = {
@@ -81,14 +77,11 @@ private[sql] trait CompilationErrors extends DataTypeErrorsBase {
def invalidSaveModeError(saveMode: String): Throwable = {
new AnalysisException(
errorClass = "INVALID_SAVE_MODE",
- messageParameters = Map("mode" -> toDSOption(saveMode))
- )
+ messageParameters = Map("mode" -> toDSOption(saveMode)))
}
def sortByWithoutBucketingError(): Throwable = {
- new AnalysisException(
- errorClass = "SORT_BY_WITHOUT_BUCKETING",
- messageParameters = Map.empty)
+ new AnalysisException(errorClass = "SORT_BY_WITHOUT_BUCKETING", messageParameters = Map.empty)
}
def bucketByUnsupportedByOperationError(operation: String): Throwable = {
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala
index 32b4198dc1a63..388a98569258b 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala
@@ -25,9 +25,9 @@ import org.apache.spark.sql.types.{DataType, Decimal, StringType}
import org.apache.spark.unsafe.types.UTF8String
/**
- * Object for grouping error messages from (most) exceptions thrown during query execution.
- * This does not include exceptions thrown during the eager execution of commands, which are
- * grouped into [[CompilationErrors]].
+ * Object for grouping error messages from (most) exceptions thrown during query execution. This
+ * does not include exceptions thrown during the eager execution of commands, which are grouped
+ * into [[CompilationErrors]].
*/
private[sql] object DataTypeErrors extends DataTypeErrorsBase {
def unsupportedOperationExceptionError(): SparkUnsupportedOperationException = {
@@ -35,13 +35,12 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase {
}
def decimalPrecisionExceedsMaxPrecisionError(
- precision: Int, maxPrecision: Int): SparkArithmeticException = {
+ precision: Int,
+ maxPrecision: Int): SparkArithmeticException = {
new SparkArithmeticException(
errorClass = "DECIMAL_PRECISION_EXCEEDS_MAX_PRECISION",
- messageParameters = Map(
- "precision" -> precision.toString,
- "maxPrecision" -> maxPrecision.toString
- ),
+ messageParameters =
+ Map("precision" -> precision.toString, "maxPrecision" -> maxPrecision.toString),
context = Array.empty,
summary = "")
}
@@ -53,8 +52,7 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase {
def outOfDecimalTypeRangeError(str: UTF8String): SparkArithmeticException = {
new SparkArithmeticException(
errorClass = "NUMERIC_OUT_OF_SUPPORTED_RANGE",
- messageParameters = Map(
- "value" -> str.toString),
+ messageParameters = Map("value" -> str.toString),
context = Array.empty,
summary = "")
}
@@ -68,25 +66,20 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase {
def nullLiteralsCannotBeCastedError(name: String): SparkUnsupportedOperationException = {
new SparkUnsupportedOperationException(
errorClass = "_LEGACY_ERROR_TEMP_2226",
- messageParameters = Map(
- "name" -> name))
+ messageParameters = Map("name" -> name))
}
def notUserDefinedTypeError(name: String, userClass: String): Throwable = {
new SparkException(
errorClass = "_LEGACY_ERROR_TEMP_2227",
- messageParameters = Map(
- "name" -> name,
- "userClass" -> userClass),
+ messageParameters = Map("name" -> name, "userClass" -> userClass),
cause = null)
}
def cannotLoadUserDefinedTypeError(name: String, userClass: String): Throwable = {
new SparkException(
errorClass = "_LEGACY_ERROR_TEMP_2228",
- messageParameters = Map(
- "name" -> name,
- "userClass" -> userClass),
+ messageParameters = Map("name" -> name, "userClass" -> userClass),
cause = null)
}
@@ -99,50 +92,42 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase {
def schemaFailToParseError(schema: String, e: Throwable): Throwable = {
new AnalysisException(
errorClass = "INVALID_SCHEMA.PARSE_ERROR",
- messageParameters = Map(
- "inputSchema" -> toSQLSchema(schema),
- "reason" -> e.getMessage
- ),
+ messageParameters = Map("inputSchema" -> toSQLSchema(schema), "reason" -> e.getMessage),
cause = Some(e))
}
def invalidDayTimeIntervalType(startFieldName: String, endFieldName: String): Throwable = {
new AnalysisException(
errorClass = "_LEGACY_ERROR_TEMP_1224",
- messageParameters = Map(
- "startFieldName" -> startFieldName,
- "endFieldName" -> endFieldName))
+ messageParameters = Map("startFieldName" -> startFieldName, "endFieldName" -> endFieldName))
}
def invalidDayTimeField(field: Byte, supportedIds: Seq[String]): Throwable = {
new AnalysisException(
errorClass = "_LEGACY_ERROR_TEMP_1223",
- messageParameters = Map(
- "field" -> field.toString,
- "supportedIds" -> supportedIds.mkString(", ")))
+ messageParameters =
+ Map("field" -> field.toString, "supportedIds" -> supportedIds.mkString(", ")))
}
def invalidYearMonthField(field: Byte, supportedIds: Seq[String]): Throwable = {
new AnalysisException(
errorClass = "_LEGACY_ERROR_TEMP_1225",
- messageParameters = Map(
- "field" -> field.toString,
- "supportedIds" -> supportedIds.mkString(", ")))
+ messageParameters =
+ Map("field" -> field.toString, "supportedIds" -> supportedIds.mkString(", ")))
}
def decimalCannotGreaterThanPrecisionError(scale: Int, precision: Int): Throwable = {
new AnalysisException(
errorClass = "_LEGACY_ERROR_TEMP_1228",
- messageParameters = Map(
- "scale" -> scale.toString,
- "precision" -> precision.toString))
+ messageParameters = Map("scale" -> scale.toString, "precision" -> precision.toString))
}
def negativeScaleNotAllowedError(scale: Int): Throwable = {
val sqlConf = QuotingUtils.toSQLConf("spark.sql.legacy.allowNegativeScaleOfDecimal")
- SparkException.internalError(s"Negative scale is not allowed: ${scale.toString}." +
- s" Set the config ${sqlConf}" +
- " to \"true\" to allow it.")
+ SparkException.internalError(
+ s"Negative scale is not allowed: ${scale.toString}." +
+ s" Set the config ${sqlConf}" +
+ " to \"true\" to allow it.")
}
def attributeNameSyntaxError(name: String): Throwable = {
@@ -154,19 +139,17 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase {
def cannotMergeIncompatibleDataTypesError(left: DataType, right: DataType): Throwable = {
new SparkException(
errorClass = "CANNOT_MERGE_INCOMPATIBLE_DATA_TYPE",
- messageParameters = Map(
- "left" -> toSQLType(left),
- "right" -> toSQLType(right)),
+ messageParameters = Map("left" -> toSQLType(left), "right" -> toSQLType(right)),
cause = null)
}
def cannotMergeDecimalTypesWithIncompatibleScaleError(
- leftScale: Int, rightScale: Int): Throwable = {
+ leftScale: Int,
+ rightScale: Int): Throwable = {
new SparkException(
errorClass = "_LEGACY_ERROR_TEMP_2124",
- messageParameters = Map(
- "leftScale" -> leftScale.toString(),
- "rightScale" -> rightScale.toString()),
+ messageParameters =
+ Map("leftScale" -> leftScale.toString(), "rightScale" -> rightScale.toString()),
cause = null)
}
@@ -179,9 +162,7 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase {
def invalidFieldName(fieldName: Seq[String], path: Seq[String], context: Origin): Throwable = {
new AnalysisException(
errorClass = "INVALID_FIELD_NAME",
- messageParameters = Map(
- "fieldName" -> toSQLId(fieldName),
- "path" -> toSQLId(path)),
+ messageParameters = Map("fieldName" -> toSQLId(fieldName), "path" -> toSQLId(path)),
origin = context)
}
@@ -227,30 +208,26 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase {
messageParameters = Map(
"expression" -> convertedValueStr,
"sourceType" -> toSQLType(StringType),
- "targetType" -> toSQLType(to),
- "ansiConfig" -> toSQLConf("spark.sql.ansi.enabled")),
+ "targetType" -> toSQLType(to)),
context = getQueryContext(context),
summary = getSummary(context))
}
def ambiguousColumnOrFieldError(
- name: Seq[String], numMatches: Int, context: Origin): Throwable = {
+ name: Seq[String],
+ numMatches: Int,
+ context: Origin): Throwable = {
new AnalysisException(
errorClass = "AMBIGUOUS_COLUMN_OR_FIELD",
- messageParameters = Map(
- "name" -> toSQLId(name),
- "n" -> numMatches.toString),
+ messageParameters = Map("name" -> toSQLId(name), "n" -> numMatches.toString),
origin = context)
}
def castingCauseOverflowError(t: String, from: DataType, to: DataType): ArithmeticException = {
new SparkArithmeticException(
errorClass = "CAST_OVERFLOW",
- messageParameters = Map(
- "value" -> t,
- "sourceType" -> toSQLType(from),
- "targetType" -> toSQLType(to),
- "ansiConfig" -> toSQLConf("spark.sql.ansi.enabled")),
+ messageParameters =
+ Map("value" -> t, "sourceType" -> toSQLType(from), "targetType" -> toSQLType(to)),
context = Array.empty,
summary = "")
}
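Beyond the reflow, the two hunks above also drop the ansiConfig entry from the CAST_INVALID_INPUT and CAST_OVERFLOW message parameters; the matching template change lives outside this file. A hedged repro sketch, assuming a SparkSession named spark with ANSI mode enabled:

import org.apache.spark.SparkArithmeticException

spark.conf.set("spark.sql.ansi.enabled", "true")
try {
  // 128 does not fit into TINYINT, so the cast overflows under ANSI semantics.
  spark.sql("SELECT CAST(128 AS TINYINT)").collect()
} catch {
  case e: SparkArithmeticException =>
    println(e.getErrorClass)         // CAST_OVERFLOW
    println(e.getMessageParameters)  // value/sourceType/targetType, no ansiConfig key
}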
@@ -267,15 +244,13 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase {
messageParameters = Map(
"methodName" -> "fieldIndex",
"className" -> "Row",
- "fieldName" -> toSQLId(fieldName))
- )
+ "fieldName" -> toSQLId(fieldName)))
}
def valueIsNullError(index: Int): Throwable = {
new SparkRuntimeException(
errorClass = "ROW_VALUE_IS_NULL",
- messageParameters = Map(
- "index" -> index.toString),
+ messageParameters = Map("index" -> index.toString),
cause = null)
}
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
index 7910c386fcf14..698a7b096e1a5 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
@@ -60,7 +60,8 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase {
}
def failToRecognizePatternAfterUpgradeError(
- pattern: String, e: Throwable): SparkUpgradeException = {
+ pattern: String,
+ e: Throwable): SparkUpgradeException = {
new SparkUpgradeException(
errorClass = "INCONSISTENT_BEHAVIOR_CROSS_VERSION.DATETIME_PATTERN_RECOGNITION",
messageParameters = Map(
@@ -73,9 +74,8 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase {
def failToRecognizePatternError(pattern: String, e: Throwable): SparkRuntimeException = {
new SparkRuntimeException(
errorClass = "_LEGACY_ERROR_TEMP_2130",
- messageParameters = Map(
- "pattern" -> toSQLValue(pattern),
- "docroot" -> SparkBuildInfo.spark_doc_root),
+ messageParameters =
+ Map("pattern" -> toSQLValue(pattern), "docroot" -> SparkBuildInfo.spark_doc_root),
cause = e)
}
@@ -93,9 +93,9 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase {
}
def invalidInputInCastToDatetimeError(
- value: Double,
- to: DataType,
- context: QueryContext): SparkDateTimeException = {
+ value: Double,
+ to: DataType,
+ context: QueryContext): SparkDateTimeException = {
invalidInputInCastToDatetimeErrorInternal(toSQLValue(value), DoubleType, to, context)
}
@@ -109,8 +109,7 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase {
messageParameters = Map(
"expression" -> sqlValue,
"sourceType" -> toSQLType(from),
- "targetType" -> toSQLType(to),
- "ansiConfig" -> toSQLConf(SqlApiConf.ANSI_ENABLED_KEY)),
+ "targetType" -> toSQLType(to)),
context = getQueryContext(context),
summary = getSummary(context))
}
@@ -132,8 +131,10 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase {
summary = getSummary(context))
}
- def cannotParseStringAsDataTypeError(pattern: String, value: String, dataType: DataType)
- : Throwable = {
+ def cannotParseStringAsDataTypeError(
+ pattern: String,
+ value: String,
+ dataType: DataType): Throwable = {
SparkException.internalError(
s"Cannot parse field value ${toSQLValue(value)} for pattern ${toSQLValue(pattern)} " +
s"as the target spark data type ${toSQLType(dataType)}.")
@@ -161,17 +162,14 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase {
def userDefinedTypeNotAnnotatedAndRegisteredError(udt: UserDefinedType[_]): Throwable = {
new SparkException(
errorClass = "_LEGACY_ERROR_TEMP_2155",
- messageParameters = Map(
- "userClass" -> udt.userClass.getName),
+ messageParameters = Map("userClass" -> udt.userClass.getName),
cause = null)
}
def cannotFindEncoderForTypeError(typeName: String): SparkUnsupportedOperationException = {
new SparkUnsupportedOperationException(
errorClass = "ENCODER_NOT_FOUND",
- messageParameters = Map(
- "typeName" -> typeName,
- "docroot" -> SparkBuildInfo.spark_doc_root))
+ messageParameters = Map("typeName" -> typeName, "docroot" -> SparkBuildInfo.spark_doc_root))
}
def cannotHaveCircularReferencesInBeanClassError(
@@ -184,8 +182,7 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase {
def cannotFindConstructorForTypeError(tpe: String): SparkUnsupportedOperationException = {
new SparkUnsupportedOperationException(
errorClass = "_LEGACY_ERROR_TEMP_2144",
- messageParameters = Map(
- "tpe" -> tpe))
+ messageParameters = Map("tpe" -> tpe))
}
def cannotHaveCircularReferencesInClassError(t: String): SparkUnsupportedOperationException = {
@@ -195,12 +192,12 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase {
}
def cannotUseInvalidJavaIdentifierAsFieldNameError(
- fieldName: String, walkedTypePath: WalkedTypePath): SparkUnsupportedOperationException = {
+ fieldName: String,
+ walkedTypePath: WalkedTypePath): SparkUnsupportedOperationException = {
new SparkUnsupportedOperationException(
errorClass = "_LEGACY_ERROR_TEMP_2140",
- messageParameters = Map(
- "fieldName" -> fieldName,
- "walkedTypePath" -> walkedTypePath.toString))
+ messageParameters =
+ Map("fieldName" -> fieldName, "walkedTypePath" -> walkedTypePath.toString))
}
def primaryConstructorNotFoundError(cls: Class[_]): SparkRuntimeException = {
@@ -213,8 +210,33 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase {
def cannotGetOuterPointerForInnerClassError(innerCls: Class[_]): SparkRuntimeException = {
new SparkRuntimeException(
errorClass = "_LEGACY_ERROR_TEMP_2154",
+ messageParameters = Map("innerCls" -> innerCls.getName))
+ }
+
+ def cannotUseKryoSerialization(): SparkRuntimeException = {
+ new SparkRuntimeException(errorClass = "CANNOT_USE_KRYO", messageParameters = Map.empty)
+ }
+
+ def notPublicClassError(name: String): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_2229",
+ messageParameters = Map("name" -> name))
+ }
+
+ def primitiveTypesNotSupportedError(): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException(errorClass = "_LEGACY_ERROR_TEMP_2230")
+ }
+
+ def elementsOfTupleExceedLimitError(): SparkUnsupportedOperationException = {
+ new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_2150")
+ }
+
+ def invalidAgnosticEncoderError(encoder: AnyRef): Throwable = {
+ new SparkRuntimeException(
+ errorClass = "INVALID_AGNOSTIC_ENCODER",
messageParameters = Map(
- "innerCls" -> innerCls.getName))
+ "encoderType" -> encoder.getClass.getName,
+ "docroot" -> SparkBuildInfo.spark_doc_root))
}
}
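The block above adds several new helpers (Kryo serialization, non-public classes, primitive types, tuple arity, agnostic encoders) alongside the reformat. A minimal usage sketch, assuming code compiled in the org.apache.spark.sql.errors package since the trait is private[sql]:

package org.apache.spark.sql.errors

// Illustrative only: a caller mixes in the trait and raises one of the new errors.
object ExecutionErrorsSketch extends ExecutionErrors {
  def requireNonKryo(useKryo: Boolean): Unit = {
    if (useKryo) {
      throw cannotUseKryoSerialization() // SparkRuntimeException with class CANNOT_USE_KRYO
    }
  }
}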
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
index e7ae9f2bfb7bb..b19607a28f06c 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
@@ -27,8 +27,8 @@ import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
import org.apache.spark.sql.catalyst.trees.Origin
/**
- * Object for grouping all error messages of the query parsing.
- * Currently it includes all ParseException.
+ * Object for grouping all error messages of the query parsing. Currently it includes all
+ * ParseException.
*/
private[sql] object QueryParsingErrors extends DataTypeErrorsBase {
@@ -37,9 +37,7 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase {
}
def parserStackOverflow(parserRuleContext: ParserRuleContext): Throwable = {
- throw new ParseException(
- errorClass = "FAILED_TO_PARSE_TOO_COMPLEX",
- ctx = parserRuleContext)
+ throw new ParseException(errorClass = "FAILED_TO_PARSE_TOO_COMPLEX", ctx = parserRuleContext)
}
def insertOverwriteDirectoryUnsupportedError(): Throwable = {
@@ -160,7 +158,9 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase {
}
def incompatibleJoinTypesError(
- joinType1: String, joinType2: String, ctx: ParserRuleContext): Throwable = {
+ joinType1: String,
+ joinType2: String,
+ ctx: ParserRuleContext): Throwable = {
new ParseException(
errorClass = "INCOMPATIBLE_JOIN_TYPES",
messageParameters = Map(
@@ -209,13 +209,12 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase {
}
def cannotParseValueTypeError(
- valueType: String, value: String, ctx: TypeConstructorContext): Throwable = {
+ valueType: String,
+ value: String,
+ ctx: TypeConstructorContext): Throwable = {
new ParseException(
errorClass = "INVALID_TYPED_LITERAL",
- messageParameters = Map(
- "valueType" -> toSQLType(valueType),
- "value" -> toSQLValue(value)
- ),
+ messageParameters = Map("valueType" -> toSQLType(valueType), "value" -> toSQLValue(value)),
ctx)
}
@@ -231,8 +230,12 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase {
ctx)
}
- def invalidNumericLiteralRangeError(rawStrippedQualifier: String, minValue: BigDecimal,
- maxValue: BigDecimal, typeName: String, ctx: NumberContext): Throwable = {
+ def invalidNumericLiteralRangeError(
+ rawStrippedQualifier: String,
+ minValue: BigDecimal,
+ maxValue: BigDecimal,
+ typeName: String,
+ ctx: NumberContext): Throwable = {
new ParseException(
errorClass = "INVALID_NUMERIC_LITERAL_RANGE",
messageParameters = Map(
@@ -259,7 +262,9 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase {
}
def fromToIntervalUnsupportedError(
- from: String, to: String, ctx: ParserRuleContext): Throwable = {
+ from: String,
+ to: String,
+ ctx: ParserRuleContext): Throwable = {
new ParseException(
errorClass = "_LEGACY_ERROR_TEMP_0028",
messageParameters = Map("from" -> from, "to" -> to),
@@ -288,7 +293,8 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase {
}
def nestedTypeMissingElementTypeError(
- dataType: String, ctx: PrimitiveDataTypeContext): Throwable = {
+ dataType: String,
+ ctx: PrimitiveDataTypeContext): Throwable = {
dataType.toUpperCase(Locale.ROOT) match {
case "ARRAY" =>
new ParseException(
@@ -309,23 +315,25 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase {
}
def partitionTransformNotExpectedError(
- name: String, expr: String, ctx: ApplyTransformContext): Throwable = {
+ name: String,
+ expr: String,
+ ctx: ApplyTransformContext): Throwable = {
new ParseException(
errorClass = "INVALID_SQL_SYNTAX.INVALID_COLUMN_REFERENCE",
- messageParameters = Map(
- "transform" -> toSQLId(name),
- "expr" -> expr),
+ messageParameters = Map("transform" -> toSQLId(name), "expr" -> expr),
ctx)
}
def wrongNumberArgumentsForTransformError(
- name: String, actualNum: Int, ctx: ApplyTransformContext): Throwable = {
+ name: String,
+ actualNum: Int,
+ ctx: ApplyTransformContext): Throwable = {
new ParseException(
errorClass = "INVALID_SQL_SYNTAX.TRANSFORM_WRONG_NUM_ARGS",
messageParameters = Map(
"transform" -> toSQLId(name),
- "expectedNum" -> "1",
- "actualNum" -> actualNum.toString),
+ "expectedNum" -> "1",
+ "actualNum" -> actualNum.toString),
ctx)
}
@@ -337,7 +345,9 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase {
}
def cannotCleanReservedNamespacePropertyError(
- property: String, ctx: ParserRuleContext, msg: String): Throwable = {
+ property: String,
+ ctx: ParserRuleContext,
+ msg: String): Throwable = {
new ParseException(
errorClass = "UNSUPPORTED_FEATURE.SET_NAMESPACE_PROPERTY",
messageParameters = Map("property" -> property, "msg" -> msg),
@@ -348,12 +358,13 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase {
new ParseException(
errorClass = "UNSUPPORTED_FEATURE.SET_PROPERTIES_AND_DBPROPERTIES",
messageParameters = Map.empty,
- ctx
- )
+ ctx)
}
def cannotCleanReservedTablePropertyError(
- property: String, ctx: ParserRuleContext, msg: String): Throwable = {
+ property: String,
+ ctx: ParserRuleContext,
+ msg: String): Throwable = {
new ParseException(
errorClass = "UNSUPPORTED_FEATURE.SET_TABLE_PROPERTY",
messageParameters = Map("property" -> property, "msg" -> msg),
@@ -361,12 +372,12 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase {
}
def duplicatedTablePathsFoundError(
- pathOne: String, pathTwo: String, ctx: ParserRuleContext): Throwable = {
+ pathOne: String,
+ pathTwo: String,
+ ctx: ParserRuleContext): Throwable = {
new ParseException(
errorClass = "_LEGACY_ERROR_TEMP_0032",
- messageParameters = Map(
- "pathOne" -> pathOne,
- "pathTwo" -> pathTwo),
+ messageParameters = Map("pathOne" -> pathOne, "pathTwo" -> pathTwo),
ctx)
}
@@ -374,15 +385,17 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase {
new ParseException(errorClass = "_LEGACY_ERROR_TEMP_0033", ctx)
}
- def operationInHiveStyleCommandUnsupportedError(operation: String,
- command: String, ctx: StatementContext, msgOpt: Option[String] = None): Throwable = {
+ def operationInHiveStyleCommandUnsupportedError(
+ operation: String,
+ command: String,
+ ctx: StatementContext,
+ msgOpt: Option[String] = None): Throwable = {
new ParseException(
errorClass = "_LEGACY_ERROR_TEMP_0034",
messageParameters = Map(
"operation" -> operation,
"command" -> command,
- "msg" -> msgOpt.map(m => s", $m").getOrElse("")
- ),
+ "msg" -> msgOpt.map(m => s", $m").getOrElse("")),
ctx)
}
@@ -415,7 +428,8 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase {
}
def addCatalogInCacheTableAsSelectNotAllowedError(
- quoted: String, ctx: CacheTableContext): Throwable = {
+ quoted: String,
+ ctx: CacheTableContext): Throwable = {
new ParseException(
errorClass = "_LEGACY_ERROR_TEMP_0037",
messageParameters = Map("quoted" -> quoted),
@@ -479,22 +493,22 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase {
}
def invalidPropertyKeyForSetQuotedConfigurationError(
- keyCandidate: String, valueStr: String, ctx: ParserRuleContext): Throwable = {
+ keyCandidate: String,
+ valueStr: String,
+ ctx: ParserRuleContext): Throwable = {
new ParseException(
errorClass = "INVALID_PROPERTY_KEY",
- messageParameters = Map(
- "key" -> toSQLConf(keyCandidate),
- "value" -> toSQLConf(valueStr)),
+ messageParameters = Map("key" -> toSQLConf(keyCandidate), "value" -> toSQLConf(valueStr)),
ctx)
}
def invalidPropertyValueForSetQuotedConfigurationError(
- valueCandidate: String, keyStr: String, ctx: ParserRuleContext): Throwable = {
+ valueCandidate: String,
+ keyStr: String,
+ ctx: ParserRuleContext): Throwable = {
new ParseException(
errorClass = "INVALID_PROPERTY_VALUE",
- messageParameters = Map(
- "value" -> toSQLConf(valueCandidate),
- "key" -> toSQLConf(keyStr)),
+ messageParameters = Map("value" -> toSQLConf(valueCandidate), "key" -> toSQLConf(keyStr)),
ctx)
}
@@ -542,12 +556,32 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase {
ctx)
}
+ def identityColumnUnsupportedDataType(
+ ctx: IdentityColumnContext,
+ dataType: String): Throwable = {
+ new ParseException("IDENTITY_COLUMNS_UNSUPPORTED_DATA_TYPE", Map("dataType" -> dataType), ctx)
+ }
+
+ def identityColumnIllegalStep(ctx: IdentityColSpecContext): Throwable = {
+ new ParseException("IDENTITY_COLUMNS_ILLEGAL_STEP", Map.empty, ctx)
+ }
+
+ def identityColumnDuplicatedSequenceGeneratorOption(
+ ctx: IdentityColSpecContext,
+ sequenceGeneratorOption: String): Throwable = {
+ new ParseException(
+ "IDENTITY_COLUMNS_DUPLICATED_SEQUENCE_GENERATOR_OPTION",
+ Map("sequenceGeneratorOption" -> sequenceGeneratorOption),
+ ctx)
+ }
+
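The three parse errors added above guard the identity-column DDL grammar. A hypothetical repro, assuming a SparkSession named spark and the GENERATED ALWAYS AS IDENTITY syntax these rules parse; an INCREMENT BY of 0 is the illegal-step case:

// Expected to fail at parse time with IDENTITY_COLUMNS_ILLEGAL_STEP.
spark.sql("""
  CREATE TABLE ids (
    id BIGINT GENERATED ALWAYS AS IDENTITY (START WITH 1 INCREMENT BY 0),
    payload STRING
  )
""")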
def createViewWithBothIfNotExistsAndReplaceError(ctx: CreateViewContext): Throwable = {
new ParseException(errorClass = "_LEGACY_ERROR_TEMP_0052", ctx)
}
def temporaryViewWithSchemaBindingMode(ctx: StatementContext): Throwable = {
- new ParseException(errorClass = "UNSUPPORTED_FEATURE.TEMPORARY_VIEW_WITH_SCHEMA_BINDING_MODE",
+ new ParseException(
+ errorClass = "UNSUPPORTED_FEATURE.TEMPORARY_VIEW_WITH_SCHEMA_BINDING_MODE",
messageParameters = Map.empty,
ctx)
}
@@ -581,9 +615,7 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase {
}
def defineTempFuncWithIfNotExistsError(ctx: ParserRuleContext): Throwable = {
- new ParseException(
- errorClass = "INVALID_SQL_SYNTAX.CREATE_TEMP_FUNC_WITH_IF_NOT_EXISTS",
- ctx)
+ new ParseException(errorClass = "INVALID_SQL_SYNTAX.CREATE_TEMP_FUNC_WITH_IF_NOT_EXISTS", ctx)
}
def unsupportedFunctionNameError(funcName: Seq[String], ctx: ParserRuleContext): Throwable = {
@@ -632,9 +664,8 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase {
def invalidNameForDropTempFunc(name: Seq[String], ctx: ParserRuleContext): Throwable = {
new ParseException(
errorClass = "INVALID_SQL_SYNTAX.MULTI_PART_NAME",
- messageParameters = Map(
- "statement" -> toSQLStmt("DROP TEMPORARY FUNCTION"),
- "funcName" -> toSQLId(name)),
+ messageParameters =
+ Map("statement" -> toSQLStmt("DROP TEMPORARY FUNCTION"), "funcName" -> toSQLId(name)),
ctx)
}
@@ -650,9 +681,7 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase {
new ParseException(errorClass = "REF_DEFAULT_VALUE_IS_NOT_ALLOWED_IN_PARTITION", ctx)
}
- def duplicateArgumentNamesError(
- arguments: Seq[String],
- ctx: ParserRuleContext): Throwable = {
+ def duplicateArgumentNamesError(arguments: Seq[String], ctx: ParserRuleContext): Throwable = {
new ParseException(
errorClass = "EXEC_IMMEDIATE_DUPLICATE_ARGUMENT_ALIASES",
messageParameters = Map("aliases" -> arguments.map(toSQLId).mkString(", ")),
@@ -679,12 +708,9 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase {
}
new ParseException(
errorClass = errorClass,
- messageParameters = alterTypeMap ++ Map(
- "columnName" -> columnName,
- "optionName" -> optionName
- ),
- ctx
- )
+ messageParameters =
+ alterTypeMap ++ Map("columnName" -> columnName, "optionName" -> optionName),
+ ctx)
}
def invalidDatetimeUnitError(
@@ -697,19 +723,17 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase {
"functionName" -> toSQLId(functionName),
"parameter" -> toSQLId("unit"),
"invalidValue" -> invalidValue),
- ctx
- )
+ ctx)
}
def invalidTableFunctionIdentifierArgumentMissingParentheses(
- ctx: ParserRuleContext, argumentName: String): Throwable = {
+ ctx: ParserRuleContext,
+ argumentName: String): Throwable = {
new ParseException(
errorClass =
"INVALID_SQL_SYNTAX.INVALID_TABLE_FUNCTION_IDENTIFIER_ARGUMENT_MISSING_PARENTHESES",
- messageParameters = Map(
- "argumentName" -> toSQLId(argumentName)),
- ctx
- )
+ messageParameters = Map("argumentName" -> toSQLId(argumentName)),
+ ctx)
}
def clusterByWithPartitionedBy(ctx: ParserRuleContext): Throwable = {
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala b/sql/api/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala
index 146012b4266dd..7d8b33aa5e228 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala
@@ -88,8 +88,8 @@ object ProcessingTimeTrigger {
}
/**
- * A [[Trigger]] that continuously processes streaming data, asynchronously checkpointing at
- * the specified interval.
+ * A [[Trigger]] that continuously processes streaming data, asynchronously checkpointing at the
+ * specified interval.
*/
case class ContinuousTrigger(intervalMs: Long) extends Trigger {
Triggers.validate(intervalMs)
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/api/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
index 1a2fbdc1fd116..5bdaebe3b073a 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -43,9 +43,12 @@ import org.apache.spark.sql.internal.{InvokeInlineUserDefinedFunction, UserDefin
*
* Based loosely on Aggregator from Algebird: https://github.com/twitter/algebird
*
- * @tparam IN The input type for the aggregation.
- * @tparam BUF The type of the intermediate value of the reduction.
- * @tparam OUT The type of the final output result.
+ * @tparam IN
+ * The input type for the aggregation.
+ * @tparam BUF
+ * The type of the intermediate value of the reduction.
+ * @tparam OUT
+ * The type of the final output result.
* @since 1.6.0
*/
@SerialVersionUID(2093413866369130093L)
@@ -58,7 +61,7 @@ abstract class Aggregator[-IN, BUF, OUT] extends Serializable with UserDefinedFu
def zero: BUF
/**
- * Combine two values to produce a new value. For performance, the function may modify `b` and
+ * Combine two values to produce a new value. For performance, the function may modify `b` and
* return it instead of constructing new object for b.
* @since 1.6.0
*/
@@ -93,8 +96,6 @@ abstract class Aggregator[-IN, BUF, OUT] extends Serializable with UserDefinedFu
* @since 1.6.0
*/
def toColumn: TypedColumn[IN, OUT] = {
- new TypedColumn[IN, OUT](
- InvokeInlineUserDefinedFunction(this, Nil),
- outputEncoder)
+ new TypedColumn[IN, OUT](InvokeInlineUserDefinedFunction(this, Nil), outputEncoder)
}
}
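Since the reflowed @tparam notes above describe IN, BUF and OUT only abstractly, here is a standard, self-contained Aggregator sketch (not part of this patch; the LongSum name is illustrative) showing how the three type parameters and toColumn fit together:

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// IN = Long input values, BUF = running sum, OUT = final Long result.
object LongSum extends Aggregator[Long, Long, Long] {
  def zero: Long = 0L
  def reduce(b: Long, a: Long): Long = b + a
  def merge(b1: Long, b2: Long): Long = b1 + b2
  def finish(reduction: Long): Long = reduction
  def bufferEncoder: Encoder[Long] = Encoders.scalaLong
  def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

// ds.select(LongSum.toColumn) yields the typed sum column built by toColumn above.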
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/api/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
index c9e0e366a7447..6a22cbfaf351e 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
@@ -52,8 +52,8 @@ sealed abstract class UserDefinedFunction extends UserDefinedFunctionLike {
def nullable: Boolean
/**
- * Returns true iff the UDF is deterministic, i.e. the UDF produces the same output given the same
- * input.
+ * Returns true iff the UDF is deterministic, i.e. the UDF produces the same output given the
+ * same input.
*
* @since 2.3.0
*/
@@ -98,7 +98,8 @@ private[spark] case class SparkUserDefinedFunction(
outputEncoder: Option[Encoder[_]] = None,
givenName: Option[String] = None,
nullable: Boolean = true,
- deterministic: Boolean = true) extends UserDefinedFunction {
+ deterministic: Boolean = true)
+ extends UserDefinedFunction {
override def withName(name: String): SparkUserDefinedFunction = {
copy(givenName = Option(name))
@@ -169,7 +170,8 @@ private[sql] case class UserDefinedAggregator[IN, BUF, OUT](
inputEncoder: Encoder[IN],
givenName: Option[String] = None,
nullable: Boolean = true,
- deterministic: Boolean = true) extends UserDefinedFunction {
+ deterministic: Boolean = true)
+ extends UserDefinedFunction {
override def withName(name: String): UserDefinedAggregator[IN, BUF, OUT] = {
copy(givenName = Option(name))
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/api/src/main/scala/org/apache/spark/sql/expressions/Window.scala
index 9c4499ee243f5..dbe2da8f97341 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/expressions/Window.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/expressions/Window.scala
@@ -32,9 +32,10 @@ import org.apache.spark.sql.Column
* Window.partitionBy("country").orderBy("date").rowsBetween(-3, 3)
* }}}
*
- * @note When ordering is not defined, an unbounded window frame (rowFrame, unboundedPreceding,
- * unboundedFollowing) is used by default. When ordering is defined, a growing window frame
- * (rangeFrame, unboundedPreceding, currentRow) is used by default.
+ * @note
+ * When ordering is not defined, an unbounded window frame (rowFrame, unboundedPreceding,
+ * unboundedFollowing) is used by default. When ordering is defined, a growing window frame
+ * (rangeFrame, unboundedPreceding, currentRow) is used by default.
*
* @since 1.4.0
*/
@@ -47,7 +48,7 @@ object Window {
*/
@scala.annotation.varargs
def partitionBy(colName: String, colNames: String*): WindowSpec = {
- spec.partitionBy(colName, colNames : _*)
+ spec.partitionBy(colName, colNames: _*)
}
/**
@@ -56,7 +57,7 @@ object Window {
*/
@scala.annotation.varargs
def partitionBy(cols: Column*): WindowSpec = {
- spec.partitionBy(cols : _*)
+ spec.partitionBy(cols: _*)
}
/**
@@ -65,7 +66,7 @@ object Window {
*/
@scala.annotation.varargs
def orderBy(colName: String, colNames: String*): WindowSpec = {
- spec.orderBy(colName, colNames : _*)
+ spec.orderBy(colName, colNames: _*)
}
/**
@@ -74,12 +75,12 @@ object Window {
*/
@scala.annotation.varargs
def orderBy(cols: Column*): WindowSpec = {
- spec.orderBy(cols : _*)
+ spec.orderBy(cols: _*)
}
/**
- * Value representing the first row in the partition, equivalent to "UNBOUNDED PRECEDING" in SQL.
- * This can be used to specify the frame boundaries:
+ * Value representing the first row in the partition, equivalent to "UNBOUNDED PRECEDING" in
+ * SQL. This can be used to specify the frame boundaries:
*
* {{{
* Window.rowsBetween(Window.unboundedPreceding, Window.currentRow)
@@ -113,22 +114,22 @@ object Window {
def currentRow: Long = 0
/**
- * Creates a [[WindowSpec]] with the frame boundaries defined,
- * from `start` (inclusive) to `end` (inclusive).
+ * Creates a [[WindowSpec]] with the frame boundaries defined, from `start` (inclusive) to `end`
+ * (inclusive).
*
* Both `start` and `end` are positions relative to the current row. For example, "0" means
* "current row", while "-1" means the row before the current row, and "5" means the fifth row
* after the current row.
*
- * We recommend users use `Window.unboundedPreceding`, `Window.unboundedFollowing`,
- * and `Window.currentRow` to specify special boundary values, rather than using integral
- * values directly.
+ * We recommend users use `Window.unboundedPreceding`, `Window.unboundedFollowing`, and
+ * `Window.currentRow` to specify special boundary values, rather than using integral values
+ * directly.
*
- * A row based boundary is based on the position of the row within the partition.
- * An offset indicates the number of rows above or below the current row, the frame for the
- * current row starts or ends. For instance, given a row based sliding frame with a lower bound
- * offset of -1 and a upper bound offset of +2. The frame for row with index 5 would range from
- * index 4 to index 7.
+ * A row based boundary is based on the position of the row within the partition. An offset
+ * indicates the number of rows above or below the current row at which the frame for the
+ * current row starts or ends. For instance, given a row based sliding frame with a lower bound
+ * offset of -1 and an upper bound offset of +2, the frame for the row with index 5 would range
+ * from index 4 to
+ * index 7.
*
* {{{
* import org.apache.spark.sql.expressions.Window
@@ -150,10 +151,12 @@ object Window {
* +---+--------+---+
* }}}
*
- * @param start boundary start, inclusive. The frame is unbounded if this is
- * the minimum long value (`Window.unboundedPreceding`).
- * @param end boundary end, inclusive. The frame is unbounded if this is the
- * maximum long value (`Window.unboundedFollowing`).
+ * @param start
+ * boundary start, inclusive. The frame is unbounded if this is the minimum long value
+ * (`Window.unboundedPreceding`).
+ * @param end
+ * boundary end, inclusive. The frame is unbounded if this is the maximum long value
+ * (`Window.unboundedFollowing`).
* @since 2.1.0
*/
// Note: when updating the doc for this method, also update WindowSpec.rowsBetween.
@@ -162,25 +165,24 @@ object Window {
}
/**
- * Creates a [[WindowSpec]] with the frame boundaries defined,
- * from `start` (inclusive) to `end` (inclusive).
+ * Creates a [[WindowSpec]] with the frame boundaries defined, from `start` (inclusive) to `end`
+ * (inclusive).
*
* Both `start` and `end` are relative to the current row. For example, "0" means "current row",
- * while "-1" means one off before the current row, and "5" means the five off after the
- * current row.
+ * while "-1" means one off before the current row, and "5" means the five off after the current
+ * row.
*
- * We recommend users use `Window.unboundedPreceding`, `Window.unboundedFollowing`,
- * and `Window.currentRow` to specify special boundary values, rather than using long values
+ * We recommend users use `Window.unboundedPreceding`, `Window.unboundedFollowing`, and
+ * `Window.currentRow` to specify special boundary values, rather than using long values
* directly.
*
- * A range-based boundary is based on the actual value of the ORDER BY
- * expression(s). An offset is used to alter the value of the ORDER BY expression,
- * for instance if the current ORDER BY expression has a value of 10 and the lower bound offset
- * is -3, the resulting lower bound for the current row will be 10 - 3 = 7. This however puts a
- * number of constraints on the ORDER BY expressions: there can be only one expression and this
- * expression must have a numerical data type. An exception can be made when the offset is
- * unbounded, because no value modification is needed, in this case multiple and non-numeric
- * ORDER BY expression are allowed.
+ * A range-based boundary is based on the actual value of the ORDER BY expression(s). An offset
+ * is used to alter the value of the ORDER BY expression, for instance if the current ORDER BY
+ * expression has a value of 10 and the lower bound offset is -3, the resulting lower bound for
+ * the current row will be 10 - 3 = 7. This however puts a number of constraints on the ORDER BY
+ * expressions: there can be only one expression and this expression must have a numerical data
+ * type. An exception can be made when the offset is unbounded, because no value modification is
+ * needed; in this case, multiple and non-numeric ORDER BY expressions are allowed.
*
* {{{
* import org.apache.spark.sql.expressions.Window
@@ -202,10 +204,12 @@ object Window {
* +---+--------+---+
* }}}
*
- * @param start boundary start, inclusive. The frame is unbounded if this is
- * the minimum long value (`Window.unboundedPreceding`).
- * @param end boundary end, inclusive. The frame is unbounded if this is the
- * maximum long value (`Window.unboundedFollowing`).
+ * @param start
+ * boundary start, inclusive. The frame is unbounded if this is the minimum long value
+ * (`Window.unboundedPreceding`).
+ * @param end
+ * boundary end, inclusive. The frame is unbounded if this is the maximum long value
+ * (`Window.unboundedFollowing`).
* @since 2.1.0
*/
// Note: when updating the doc for this method, also update WindowSpec.rangeBetween.
@@ -234,4 +238,4 @@ object Window {
* @since 1.4.0
*/
@Stable
-class Window private() // So we can see Window in JavaDoc.
+class Window private () // So we can see Window in JavaDoc.
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/api/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
index a888563d66a71..9abdee9c79ebc 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.internal.{ColumnNode, SortOrder, Window => EvalWindo
* @since 1.4.0
*/
@Stable
-class WindowSpec private[sql](
+class WindowSpec private[sql] (
partitionSpec: Seq[ColumnNode],
orderSpec: Seq[SortOrder],
frame: Option[WindowFrame]) {
@@ -78,15 +78,15 @@ class WindowSpec private[sql](
* "current row", while "-1" means the row before the current row, and "5" means the fifth row
* after the current row.
*
- * We recommend users use `Window.unboundedPreceding`, `Window.unboundedFollowing`,
- * and `Window.currentRow` to specify special boundary values, rather than using integral
- * values directly.
+ * We recommend users use `Window.unboundedPreceding`, `Window.unboundedFollowing`, and
+ * `Window.currentRow` to specify special boundary values, rather than using integral values
+ * directly.
*
- * A row based boundary is based on the position of the row within the partition.
- * An offset indicates the number of rows above or below the current row, the frame for the
- * current row starts or ends. For instance, given a row based sliding frame with a lower bound
- * offset of -1 and a upper bound offset of +2. The frame for row with index 5 would range from
- * index 4 to index 7.
+ * A row based boundary is based on the position of the row within the partition. An offset
+ * indicates the number of rows above or below the current row at which the frame for the
+ * current row starts or ends. For instance, given a row based sliding frame with a lower bound
+ * offset of -1 and an upper bound offset of +2, the frame for the row with index 5 would range
+ * from index 4 to
+ * index 7.
*
* {{{
* import org.apache.spark.sql.expressions.Window
@@ -108,10 +108,12 @@ class WindowSpec private[sql](
* +---+--------+---+
* }}}
*
- * @param start boundary start, inclusive. The frame is unbounded if this is
- * the minimum long value (`Window.unboundedPreceding`).
- * @param end boundary end, inclusive. The frame is unbounded if this is the
- * maximum long value (`Window.unboundedFollowing`).
+ * @param start
+ * boundary start, inclusive. The frame is unbounded if this is the minimum long value
+ * (`Window.unboundedPreceding`).
+ * @param end
+ * boundary end, inclusive. The frame is unbounded if this is the maximum long value
+ * (`Window.unboundedFollowing`).
* @since 1.4.0
*/
// Note: when updating the doc for this method, also update Window.rowsBetween.
@@ -136,22 +138,21 @@ class WindowSpec private[sql](
/**
* Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive).
*
- * Both `start` and `end` are relative from the current row. For example, "0" means
- * "current row", while "-1" means one off before the current row, and "5" means the five off
- * after the current row.
+ * Both `start` and `end` are relative from the current row. For example, "0" means "current
+ * row", while "-1" means one off before the current row, and "5" means the five off after the
+ * current row.
*
- * We recommend users use `Window.unboundedPreceding`, `Window.unboundedFollowing`,
- * and `Window.currentRow` to specify special boundary values, rather than using long values
+ * We recommend users use `Window.unboundedPreceding`, `Window.unboundedFollowing`, and
+ * `Window.currentRow` to specify special boundary values, rather than using long values
* directly.
*
- * A range-based boundary is based on the actual value of the ORDER BY
- * expression(s). An offset is used to alter the value of the ORDER BY expression, for
- * instance if the current order by expression has a value of 10 and the lower bound offset
- * is -3, the resulting lower bound for the current row will be 10 - 3 = 7. This however puts a
- * number of constraints on the ORDER BY expressions: there can be only one expression and this
- * expression must have a numerical data type. An exception can be made when the offset is
- * unbounded, because no value modification is needed, in this case multiple and non-numeric
- * ORDER BY expression are allowed.
+ * A range-based boundary is based on the actual value of the ORDER BY expression(s). An offset
+ * is used to alter the value of the ORDER BY expression, for instance if the current order by
+ * expression has a value of 10 and the lower bound offset is -3, the resulting lower bound for
+ * the current row will be 10 - 3 = 7. This however puts a number of constraints on the ORDER BY
+ * expressions: there can be only one expression and this expression must have a numerical data
+ * type. An exception can be made when the offset is unbounded, because no value modification is
+ * needed; in this case, multiple and non-numeric ORDER BY expressions are allowed.
*
* {{{
* import org.apache.spark.sql.expressions.Window
@@ -173,10 +174,12 @@ class WindowSpec private[sql](
* +---+--------+---+
* }}}
*
- * @param start boundary start, inclusive. The frame is unbounded if this is
- * the minimum long value (`Window.unboundedPreceding`).
- * @param end boundary end, inclusive. The frame is unbounded if this is the
- * maximum long value (`Window.unboundedFollowing`).
+ * @param start
+ * boundary start, inclusive. The frame is unbounded if this is the minimum long value
+ * (`Window.unboundedPreceding`).
+ * @param end
+ * boundary end, inclusive. The frame is unbounded if this is the maximum long value
+ * (`Window.unboundedFollowing`).
* @since 1.4.0
*/
// Note: when updating the doc for this method, also update Window.rangeBetween.
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/api/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
index a4aa9c312aff2..5e7c993fae414 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
@@ -26,18 +26,21 @@ import org.apache.spark.sql.types._
* The base class for implementing user-defined aggregate functions (UDAF).
*
* @since 1.5.0
- * @deprecated UserDefinedAggregateFunction is deprecated.
- * Aggregator[IN, BUF, OUT] should now be registered as a UDF via the functions.udaf(agg) method.
+ * @deprecated
+ * UserDefinedAggregateFunction is deprecated. Aggregator[IN, BUF, OUT] should now be registered
+ * as a UDF via the functions.udaf(agg) method.
*/
@Stable
-@deprecated("Aggregator[IN, BUF, OUT] should now be registered as a UDF" +
- " via the functions.udaf(agg) method.", "3.0.0")
+@deprecated(
+ "Aggregator[IN, BUF, OUT] should now be registered as a UDF" +
+ " via the functions.udaf(agg) method.",
+ "3.0.0")
abstract class UserDefinedAggregateFunction extends Serializable with UserDefinedFunctionLike {
/**
- * A `StructType` represents data types of input arguments of this aggregate function.
- * For example, if a [[UserDefinedAggregateFunction]] expects two input arguments
- * with type of `DoubleType` and `LongType`, the returned `StructType` will look like
+ * A `StructType` represents data types of input arguments of this aggregate function. For
+ * example, if a [[UserDefinedAggregateFunction]] expects two input arguments with type of
+ * `DoubleType` and `LongType`, the returned `StructType` will look like
*
* ```
* new StructType()
@@ -45,18 +48,17 @@ abstract class UserDefinedAggregateFunction extends Serializable with UserDefine
* .add("longInput", LongType)
* ```
*
- * The name of a field of this `StructType` is only used to identify the corresponding
- * input argument. Users can choose names to identify the input arguments.
+ * The name of a field of this `StructType` is only used to identify the corresponding input
+ * argument. Users can choose names to identify the input arguments.
*
* @since 1.5.0
*/
def inputSchema: StructType
/**
- * A `StructType` represents data types of values in the aggregation buffer.
- * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values
- * (i.e. two intermediate values) with type of `DoubleType` and `LongType`,
- * the returned `StructType` will look like
+ * A `StructType` represents data types of values in the aggregation buffer. For example, if a
+ * [[UserDefinedAggregateFunction]]'s buffer has two values (i.e. two intermediate values) with
+ * type of `DoubleType` and `LongType`, the returned `StructType` will look like
*
* ```
* new StructType()
@@ -64,8 +66,8 @@ abstract class UserDefinedAggregateFunction extends Serializable with UserDefine
* .add("longInput", LongType)
* ```
*
- * The name of a field of this `StructType` is only used to identify the corresponding
- * buffer value. Users can choose names to identify the input arguments.
+ * The name of a field of this `StructType` is only used to identify the corresponding buffer
+ * value. Users can choose names to identify the input arguments.
*
* @since 1.5.0
*/
@@ -79,8 +81,8 @@ abstract class UserDefinedAggregateFunction extends Serializable with UserDefine
def dataType: DataType
/**
- * Returns true iff this function is deterministic, i.e. given the same input,
- * always return the same output.
+ * Returns true iff this function is deterministic, i.e. given the same input, always return the
+ * same output.
*
* @since 1.5.0
*/
@@ -90,8 +92,8 @@ abstract class UserDefinedAggregateFunction extends Serializable with UserDefine
* Initializes the given aggregation buffer, i.e. the zero value of the aggregation buffer.
*
* The contract should be that applying the merge function on two initial buffers should just
- * return the initial buffer itself, i.e.
- * `merge(initialBuffer, initialBuffer)` should equal `initialBuffer`.
+ * return the initial buffer itself, i.e. `merge(initialBuffer, initialBuffer)` should equal
+ * `initialBuffer`.
*
* @since 1.5.0
*/
@@ -134,8 +136,8 @@ abstract class UserDefinedAggregateFunction extends Serializable with UserDefine
}
/**
- * Creates a `Column` for this UDAF using the distinct values of the given
- * `Column`s as input arguments.
+ * Creates a `Column` for this UDAF using the distinct values of the given `Column`s as input
+ * arguments.
*
* @since 1.5.0
*/
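The deprecation note reformatted above points at functions.udaf as the replacement for UserDefinedAggregateFunction. A short sketch of that path, assuming a SparkSession named spark and the illustrative LongSum Aggregator sketched earlier:

import org.apache.spark.sql.functions.udaf

spark.udf.register("long_sum", udaf(LongSum))
spark.sql("SELECT long_sum(value) FROM VALUES (1L), (2L), (3L) AS t(value)").show()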
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
index e46d6c95b31ae..0662b8f2b271f 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
@@ -104,9 +104,9 @@ object functions {
/**
* Creates a [[Column]] of literal value.
*
- * The passed in object is returned directly if it is already a [[Column]].
- * If the object is a Scala Symbol, it is converted into a [[Column]] also.
- * Otherwise, a new [[Column]] is created to represent the literal value.
+ * The passed in object is returned directly if it is already a [[Column]]. If the object is a
+ * Scala Symbol, it is converted into a [[Column]] also. Otherwise, a new [[Column]] is created
+ * to represent the literal value.
*
* @group normal_funcs
* @since 1.3.0
@@ -121,7 +121,7 @@ object functions {
// method, `typedLit[Any](literal)` will always fail and fallback to `Literal.apply`. Hence,
// we can just manually call `Literal.apply` to skip the expensive `ScalaReflection` code.
// This is significantly better when there are many threads calling `lit` concurrently.
- Column(internal.Literal(literal))
+ Column(internal.Literal(literal))
}
}
@@ -133,26 +133,26 @@ object functions {
* @group normal_funcs
* @since 2.2.0
*/
- def typedLit[T : TypeTag](literal: T): Column = {
+ def typedLit[T: TypeTag](literal: T): Column = {
typedlit(literal)
}
/**
* Creates a [[Column]] of literal value.
*
- * The passed in object is returned directly if it is already a [[Column]].
- * If the object is a Scala Symbol, it is converted into a [[Column]] also.
- * Otherwise, a new [[Column]] is created to represent the literal value.
- * The difference between this function and [[lit]] is that this function
- * can handle parameterized scala types e.g.: List, Seq and Map.
+ * The passed in object is returned directly if it is already a [[Column]]. If the object is a
+ * Scala Symbol, it is converted into a [[Column]] also. Otherwise, a new [[Column]] is created
+ * to represent the literal value. The difference between this function and [[lit]] is that this
+ * function can handle parameterized scala types e.g.: List, Seq and Map.
*
- * @note `typedlit` will call expensive Scala reflection APIs. `lit` is preferred if parameterized
- * Scala types are not used.
+ * @note
+ * `typedlit` will call expensive Scala reflection APIs. `lit` is preferred if parameterized
+ * Scala types are not used.
*
* @group normal_funcs
* @since 3.2.0
*/
- def typedlit[T : TypeTag](literal: T): Column = {
+ def typedlit[T: TypeTag](literal: T): Column = {
literal match {
case c: Column => c
case s: Symbol => new ColumnName(s.name)
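The lit versus typedlit distinction documented above comes down to whether Scala reflection is involved; a small usage sketch:

import org.apache.spark.sql.functions.{lit, typedlit}

val simple = lit(42)                            // plain literal, no reflection
val seqCol = typedlit(Seq(1, 2, 3))             // parameterized type handled via TypeTag
val mapCol = typedlit(Map("a" -> 1, "b" -> 2))  // also works for Map and other generics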
@@ -178,8 +178,8 @@ object functions {
def asc(columnName: String): Column = Column(columnName).asc
/**
- * Returns a sort expression based on ascending order of the column,
- * and null values return before non-null values.
+ * Returns a sort expression based on ascending order of the column, and null values return
+ * before non-null values.
* {{{
* df.sort(asc_nulls_first("dept"), desc("age"))
* }}}
@@ -190,8 +190,8 @@ object functions {
def asc_nulls_first(columnName: String): Column = Column(columnName).asc_nulls_first
/**
- * Returns a sort expression based on ascending order of the column,
- * and null values appear after non-null values.
+ * Returns a sort expression based on ascending order of the column, and null values appear
+ * after non-null values.
* {{{
* df.sort(asc_nulls_last("dept"), desc("age"))
* }}}
@@ -213,8 +213,8 @@ object functions {
def desc(columnName: String): Column = Column(columnName).desc
/**
- * Returns a sort expression based on the descending order of the column,
- * and null values appear before non-null values.
+ * Returns a sort expression based on the descending order of the column, and null values appear
+ * before non-null values.
* {{{
* df.sort(asc("dept"), desc_nulls_first("age"))
* }}}
@@ -225,8 +225,8 @@ object functions {
def desc_nulls_first(columnName: String): Column = Column(columnName).desc_nulls_first
/**
- * Returns a sort expression based on the descending order of the column,
- * and null values appear after non-null values.
+ * Returns a sort expression based on the descending order of the column, and null values appear
+ * after non-null values.
* {{{
* df.sort(asc("dept"), desc_nulls_last("age"))
* }}}
@@ -236,7 +236,6 @@ object functions {
*/
def desc_nulls_last(columnName: String): Column = Column(columnName).desc_nulls_last
-
//////////////////////////////////////////////////////////////////////////////////////////////
// Aggregate functions
//////////////////////////////////////////////////////////////////////////////////////////////
@@ -285,12 +284,14 @@ object functions {
* @group agg_funcs
* @since 2.1.0
*/
- def approx_count_distinct(columnName: String): Column = approx_count_distinct(column(columnName))
+ def approx_count_distinct(columnName: String): Column = approx_count_distinct(
+ column(columnName))
/**
* Aggregate function: returns the approximate number of distinct items in a group.
*
- * @param rsd maximum relative standard deviation allowed (default = 0.05)
+ * @param rsd
+ * maximum relative standard deviation allowed (default = 0.05)
*
* @group agg_funcs
* @since 2.1.0
@@ -302,7 +303,8 @@ object functions {
/**
* Aggregate function: returns the approximate number of distinct items in a group.
*
- * @param rsd maximum relative standard deviation allowed (default = 0.05)
+ * @param rsd
+ * maximum relative standard deviation allowed (default = 0.05)
*
* @group agg_funcs
* @since 2.1.0
@@ -330,8 +332,9 @@ object functions {
/**
* Aggregate function: returns a list of objects with duplicates.
*
- * @note The function is non-deterministic because the order of collected results depends
- * on the order of the rows which may be non-deterministic after a shuffle.
+ * @note
+ * The function is non-deterministic because the order of collected results depends on the
+ * order of the rows which may be non-deterministic after a shuffle.
*
* @group agg_funcs
* @since 1.6.0
@@ -341,8 +344,9 @@ object functions {
/**
* Aggregate function: returns a list of objects with duplicates.
*
- * @note The function is non-deterministic because the order of collected results depends
- * on the order of the rows which may be non-deterministic after a shuffle.
+ * @note
+ * The function is non-deterministic because the order of collected results depends on the
+ * order of the rows which may be non-deterministic after a shuffle.
*
* @group agg_funcs
* @since 1.6.0
@@ -352,8 +356,9 @@ object functions {
/**
* Aggregate function: returns a set of objects with duplicate elements eliminated.
*
- * @note The function is non-deterministic because the order of collected results depends
- * on the order of the rows which may be non-deterministic after a shuffle.
+ * @note
+ * The function is non-deterministic because the order of collected results depends on the
+ * order of the rows which may be non-deterministic after a shuffle.
*
* @group agg_funcs
* @since 1.6.0
@@ -363,8 +368,9 @@ object functions {
/**
* Aggregate function: returns a set of objects with duplicate elements eliminated.
*
- * @note The function is non-deterministic because the order of collected results depends
- * on the order of the rows which may be non-deterministic after a shuffle.
+ * @note
+ * The function is non-deterministic because the order of collected results depends on the
+ * order of the rows which may be non-deterministic after a shuffle.
*
* @group agg_funcs
* @since 1.6.0
@@ -372,10 +378,10 @@ object functions {
def collect_set(columnName: String): Column = collect_set(Column(columnName))
/**
- * Returns a count-min sketch of a column with the given esp, confidence and seed. The result
- * is an array of bytes, which can be deserialized to a `CountMinSketch` before usage.
- * Count-min sketch is a probabilistic data structure used for cardinality estimation using
- * sub-linear space.
+ * Returns a count-min sketch of a column with the given eps, confidence and seed. The result is
+ * an array of bytes, which can be deserialized to a `CountMinSketch` before usage. Count-min
+ * sketch is a probabilistic data structure used for cardinality estimation using sub-linear
+ * space.
*
* @group agg_funcs
* @since 3.5.0
@@ -383,6 +389,18 @@ object functions {
def count_min_sketch(e: Column, eps: Column, confidence: Column, seed: Column): Column =
Column.fn("count_min_sketch", e, eps, confidence, seed)
+ /**
+ * Returns a count-min sketch of a column with the given eps, confidence and seed. The result is
+ * an array of bytes, which can be deserialized to a `CountMinSketch` before usage. Count-min
+ * sketch is a probabilistic data structure used for cardinality estimation using sub-linear
+ * space.
+ *
+ * @group agg_funcs
+ * @since 4.0.0
+ */
+ def count_min_sketch(e: Column, eps: Column, confidence: Column): Column =
+ count_min_sketch(e, eps, confidence, lit(SparkClassUtils.random.nextInt))
+
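The new three-argument count_min_sketch overload above simply delegates to the four-argument form with a random seed; a usage sketch, assuming a DataFrame named df with a column id:

import org.apache.spark.sql.functions.{col, count_min_sketch, lit}

// Binary sketch with eps = 0.01 and confidence = 0.95; the seed is picked by the overload.
val sketchBytes = df.select(count_min_sketch(col("id"), lit(0.01), lit(0.95)))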
private[spark] def collect_top_k(e: Column, num: Int, reverse: Boolean): Column =
Column.internalFn("collect_top_k", e, lit(num), lit(reverse))
@@ -443,7 +461,7 @@ object functions {
*/
@scala.annotation.varargs
def countDistinct(columnName: String, columnNames: String*): Column =
- count_distinct(Column(columnName), columnNames.map(Column.apply) : _*)
+ count_distinct(Column(columnName), columnNames.map(Column.apply): _*)
/**
* Aggregate function: returns the number of distinct items in a group.
@@ -499,8 +517,9 @@ object functions {
* The function by default returns the first values it sees. It will return the first non-null
* value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
*
- * @note The function is non-deterministic because its results depends on the order of the rows
- * which may be non-deterministic after a shuffle.
+ * @note
+ *   The function is non-deterministic because its results depend on the order of the rows
+ * which may be non-deterministic after a shuffle.
*
* @group agg_funcs
* @since 2.0.0
@@ -514,8 +533,9 @@ object functions {
* The function by default returns the first values it sees. It will return the first non-null
* value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
*
- * @note The function is non-deterministic because its results depends on the order of the rows
- * which may be non-deterministic after a shuffle.
+ * @note
+ *   The function is non-deterministic because its results depend on the order of the rows
+ * which may be non-deterministic after a shuffle.
*
* @group agg_funcs
* @since 2.0.0
@@ -530,8 +550,9 @@ object functions {
* The function by default returns the first values it sees. It will return the first non-null
* value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
*
- * @note The function is non-deterministic because its results depends on the order of the rows
- * which may be non-deterministic after a shuffle.
+ * @note
+ *   The function is non-deterministic because its results depend on the order of the rows
+ * which may be non-deterministic after a shuffle.
*
* @group agg_funcs
* @since 1.3.0
@@ -544,8 +565,9 @@ object functions {
* The function by default returns the first values it sees. It will return the first non-null
* value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
*
- * @note The function is non-deterministic because its results depends on the order of the rows
- * which may be non-deterministic after a shuffle.
+ * @note
+ *   The function is non-deterministic because its results depend on the order of the rows
+ * which may be non-deterministic after a shuffle.
*
* @group agg_funcs
* @since 1.3.0
@@ -555,8 +577,9 @@ object functions {
/**
* Aggregate function: returns the first value in a group.
*
- * @note The function is non-deterministic because its results depends on the order of the rows
- * which may be non-deterministic after a shuffle.
+ * @note
+ *   The function is non-deterministic because its results depend on the order of the rows
+ * which may be non-deterministic after a shuffle.
*
* @group agg_funcs
* @since 3.5.0
@@ -569,8 +592,9 @@ object functions {
* The function by default returns the first values it sees. It will return the first non-null
* value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
*
- * @note The function is non-deterministic because its results depends on the order of the rows
- * which may be non-deterministic after a shuffle.
+ * @note
+ *   The function is non-deterministic because its results depend on the order of the rows
+ * which may be non-deterministic after a shuffle.
*
* @group agg_funcs
* @since 3.5.0
@@ -579,8 +603,8 @@ object functions {
Column.fn("first_value", e, ignoreNulls)
/**
- * Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated
- * or not, returns 1 for aggregated or 0 for not aggregated in the result set.
+ * Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated or
+ * not, returns 1 for aggregated or 0 for not aggregated in the result set.
*
* @group agg_funcs
* @since 2.0.0
@@ -588,8 +612,8 @@ object functions {
def grouping(e: Column): Column = Column.fn("grouping", e)
/**
- * Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated
- * or not, returns 1 for aggregated or 0 for not aggregated in the result set.
+ * Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated or
+ * not, returns 1 for aggregated or 0 for not aggregated in the result set.
*
* @group agg_funcs
* @since 2.0.0
@@ -603,8 +627,9 @@ object functions {
* (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn)
* }}}
*
- * @note The list of columns should match with grouping columns exactly, or empty (means all the
- * grouping columns).
+ * @note
+ * The list of columns should match the grouping columns exactly, or be empty (meaning all
+ * the grouping columns).
*
* @group agg_funcs
* @since 2.0.0
@@ -618,18 +643,19 @@ object functions {
* (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn)
* }}}
*
- * @note The list of columns should match with grouping columns exactly.
+ * @note
+ * The list of columns should match the grouping columns exactly.
*
* @group agg_funcs
* @since 2.0.0
*/
def grouping_id(colName: String, colNames: String*): Column = {
- grouping_id((Seq(colName) ++ colNames).map(n => Column(n)) : _*)
+ grouping_id((Seq(colName) ++ colNames).map(n => Column(n)): _*)
}
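A small sketch of how `grouping` and `grouping_id` behave together with `cube`; it assumes a SparkSession with `spark.implicits._` in scope, and the data and column names are made up.

{{{
import org.apache.spark.sql.functions._

// Hypothetical sales data.
val sales = Seq(("EU", "A", 10), ("EU", "B", 20), ("US", "A", 30)).toDF("region", "product", "amount")

sales.cube($"region", $"product")
  .agg(
    sum($"amount").as("total"),
    grouping($"region").as("region_aggregated"),  // 1 when region is rolled up
    // gid = (grouping(region) << 1) + grouping(product)
    grouping_id().as("gid"))
  .orderBy($"gid")
  .show()
}}}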
/**
- * Aggregate function: returns the updatable binary representation of the Datasketches
- * HllSketch configured with lgConfigK arg.
+ * Aggregate function: returns the updatable binary representation of the Datasketches HllSketch
+ * configured with lgConfigK arg.
*
* @group agg_funcs
* @since 3.5.0
@@ -638,8 +664,8 @@ object functions {
Column.fn("hll_sketch_agg", e, lgConfigK)
/**
- * Aggregate function: returns the updatable binary representation of the Datasketches
- * HllSketch configured with lgConfigK arg.
+ * Aggregate function: returns the updatable binary representation of the Datasketches HllSketch
+ * configured with lgConfigK arg.
*
* @group agg_funcs
* @since 3.5.0
@@ -648,8 +674,8 @@ object functions {
Column.fn("hll_sketch_agg", e, lit(lgConfigK))
/**
- * Aggregate function: returns the updatable binary representation of the Datasketches
- * HllSketch configured with lgConfigK arg.
+ * Aggregate function: returns the updatable binary representation of the Datasketches HllSketch
+ * configured with lgConfigK arg.
*
* @group agg_funcs
* @since 3.5.0
@@ -659,8 +685,8 @@ object functions {
}
/**
- * Aggregate function: returns the updatable binary representation of the Datasketches
- * HllSketch configured with default lgConfigK value.
+ * Aggregate function: returns the updatable binary representation of the Datasketches HllSketch
+ * configured with default lgConfigK value.
*
* @group agg_funcs
* @since 3.5.0
@@ -669,8 +695,8 @@ object functions {
Column.fn("hll_sketch_agg", e)
/**
- * Aggregate function: returns the updatable binary representation of the Datasketches
- * HllSketch configured with default lgConfigK value.
+ * Aggregate function: returns the updatable binary representation of the Datasketches HllSketch
+ * configured with default lgConfigK value.
*
* @group agg_funcs
* @since 3.5.0
@@ -681,9 +707,9 @@ object functions {
/**
* Aggregate function: returns the updatable binary representation of the Datasketches
- * HllSketch, generated by merging previously created Datasketches HllSketch instances
- * via a Datasketches Union instance. Throws an exception if sketches have different
- * lgConfigK values and allowDifferentLgConfigK is set to false.
+ * HllSketch, generated by merging previously created Datasketches HllSketch instances via a
+ * Datasketches Union instance. Throws an exception if sketches have different lgConfigK values
+ * and allowDifferentLgConfigK is set to false.
*
* @group agg_funcs
* @since 3.5.0
@@ -693,9 +719,9 @@ object functions {
/**
* Aggregate function: returns the updatable binary representation of the Datasketches
- * HllSketch, generated by merging previously created Datasketches HllSketch instances
- * via a Datasketches Union instance. Throws an exception if sketches have different
- * lgConfigK values and allowDifferentLgConfigK is set to false.
+ * HllSketch, generated by merging previously created Datasketches HllSketch instances via a
+ * Datasketches Union instance. Throws an exception if sketches have different lgConfigK values
+ * and allowDifferentLgConfigK is set to false.
*
* @group agg_funcs
* @since 3.5.0
@@ -705,9 +731,9 @@ object functions {
/**
* Aggregate function: returns the updatable binary representation of the Datasketches
- * HllSketch, generated by merging previously created Datasketches HllSketch instances
- * via a Datasketches Union instance. Throws an exception if sketches have different
- * lgConfigK values and allowDifferentLgConfigK is set to false.
+ * HllSketch, generated by merging previously created Datasketches HllSketch instances via a
+ * Datasketches Union instance. Throws an exception if sketches have different lgConfigK values
+ * and allowDifferentLgConfigK is set to false.
*
* @group agg_funcs
* @since 3.5.0
@@ -718,9 +744,8 @@ object functions {
/**
* Aggregate function: returns the updatable binary representation of the Datasketches
- * HllSketch, generated by merging previously created Datasketches HllSketch instances
- * via a Datasketches Union instance. Throws an exception if sketches have different
- * lgConfigK values.
+ * HllSketch, generated by merging previously created Datasketches HllSketch instances via a
+ * Datasketches Union instance. Throws an exception if sketches have different lgConfigK values.
*
* @group agg_funcs
* @since 3.5.0
@@ -730,9 +755,8 @@ object functions {
/**
* Aggregate function: returns the updatable binary representation of the Datasketches
- * HllSketch, generated by merging previously created Datasketches HllSketch instances
- * via a Datasketches Union instance. Throws an exception if sketches have different
- * lgConfigK values.
+ * HllSketch, generated by merging previously created Datasketches HllSketch instances via a
+ * Datasketches Union instance. Throws an exception if sketches have different lgConfigK values.
*
* @group agg_funcs
* @since 3.5.0
@@ -763,8 +787,9 @@ object functions {
* The function by default returns the last values it sees. It will return the last non-null
* value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
*
- * @note The function is non-deterministic because its results depends on the order of the rows
- * which may be non-deterministic after a shuffle.
+ * @note
+ * The function is non-deterministic because its result depends on the order of the rows,
+ * which may be non-deterministic after a shuffle.
*
* @group agg_funcs
* @since 2.0.0
@@ -778,8 +803,9 @@ object functions {
* The function by default returns the last values it sees. It will return the last non-null
* value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
*
- * @note The function is non-deterministic because its results depends on the order of the rows
- * which may be non-deterministic after a shuffle.
+ * @note
+ * The function is non-deterministic because its result depends on the order of the rows,
+ * which may be non-deterministic after a shuffle.
*
* @group agg_funcs
* @since 2.0.0
@@ -794,8 +820,9 @@ object functions {
* The function by default returns the last values it sees. It will return the last non-null
* value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
*
- * @note The function is non-deterministic because its results depends on the order of the rows
- * which may be non-deterministic after a shuffle.
+ * @note
+ * The function is non-deterministic because its result depends on the order of the rows,
+ * which may be non-deterministic after a shuffle.
*
* @group agg_funcs
* @since 1.3.0
@@ -808,8 +835,9 @@ object functions {
* The function by default returns the last values it sees. It will return the last non-null
* value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
*
- * @note The function is non-deterministic because its results depends on the order of the rows
- * which may be non-deterministic after a shuffle.
+ * @note
+ * The function is non-deterministic because its result depends on the order of the rows,
+ * which may be non-deterministic after a shuffle.
*
* @group agg_funcs
* @since 1.3.0
@@ -819,8 +847,9 @@ object functions {
/**
* Aggregate function: returns the last value in a group.
*
- * @note The function is non-deterministic because its results depends on the order of the rows
- * which may be non-deterministic after a shuffle.
+ * @note
+ * The function is non-deterministic because its result depends on the order of the rows,
+ * which may be non-deterministic after a shuffle.
*
* @group agg_funcs
* @since 3.5.0
@@ -833,8 +862,9 @@ object functions {
* The function by default returns the last values it sees. It will return the last non-null
* value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
*
- * @note The function is non-deterministic because its results depends on the order of the rows
- * which may be non-deterministic after a shuffle.
+ * @note
+ * The function is non-deterministic because its result depends on the order of the rows,
+ * which may be non-deterministic after a shuffle.
*
* @group agg_funcs
* @since 3.5.0
@@ -881,8 +911,9 @@ object functions {
/**
* Aggregate function: returns the value associated with the maximum value of ord.
*
- * @note The function is non-deterministic so the output order can be different for
- * those associated the same values of `e`.
+ * @note
+ * The function is non-deterministic, so the output order can be different for those
+ * associated with the same values of `e`.
*
* @group agg_funcs
* @since 3.3.0
@@ -890,8 +921,7 @@ object functions {
def max_by(e: Column, ord: Column): Column = Column.fn("max_by", e, ord)
/**
- * Aggregate function: returns the average of the values in a group.
- * Alias for avg.
+ * Aggregate function: returns the average of the values in a group. Alias for avg.
*
* @group agg_funcs
* @since 1.4.0
@@ -899,8 +929,7 @@ object functions {
def mean(e: Column): Column = avg(e)
/**
- * Aggregate function: returns the average of the values in a group.
- * Alias for avg.
+ * Aggregate function: returns the average of the values in a group. Alias for avg.
*
* @group agg_funcs
* @since 1.4.0
@@ -934,8 +963,9 @@ object functions {
/**
* Aggregate function: returns the value associated with the minimum value of ord.
*
- * @note The function is non-deterministic so the output order can be different for
- * those associated the same values of `e`.
+ * @note
+ * The function is non-deterministic, so the output order can be different for those
+ * associated with the same values of `e`.
*
* @group agg_funcs
* @since 3.3.0
@@ -943,8 +973,8 @@ object functions {
def min_by(e: Column, ord: Column): Column = Column.fn("min_by", e, ord)
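A short sketch of `max_by`/`min_by` as documented above, picking the value of one column at the extreme of another; the data is hypothetical and a SparkSession with implicits is assumed.

{{{
import org.apache.spark.sql.functions._

val readings = Seq(("sensorA", 10, 3.1), ("sensorA", 20, 2.7), ("sensorB", 15, 9.9))
  .toDF("sensor", "ts", "value")

readings.groupBy($"sensor")
  .agg(
    max_by($"value", $"ts").as("latest_value"),    // value at the maximum ts
    min_by($"value", $"ts").as("earliest_value"))  // value at the minimum ts
  .show()
}}}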
/**
- * Aggregate function: returns the exact percentile(s) of numeric column `expr` at the
- * given percentage(s) with value range in [0.0, 1.0].
+ * Aggregate function: returns the exact percentile(s) of numeric column `expr` at the given
+ * percentage(s) with value range in [0.0, 1.0].
*
* @group agg_funcs
* @since 3.5.0
@@ -952,8 +982,8 @@ object functions {
def percentile(e: Column, percentage: Column): Column = Column.fn("percentile", e, percentage)
/**
- * Aggregate function: returns the exact percentile(s) of numeric column `expr` at the
- * given percentage(s) with value range in [0.0, 1.0].
+ * Aggregate function: returns the exact percentile(s) of numeric column `expr` at the given
+ * percentage(s) with value range in [0.0, 1.0].
*
* @group agg_funcs
* @since 3.5.0
@@ -962,17 +992,16 @@ object functions {
Column.fn("percentile", e, percentage, frequency)
/**
- * Aggregate function: returns the approximate `percentile` of the numeric column `col` which
- * is the smallest value in the ordered `col` values (sorted from least to greatest) such that
- * no more than `percentage` of `col` values is less than the value or equal to that value.
+ * Aggregate function: returns the approximate `percentile` of the numeric column `col`, which
+ * is the smallest value in the ordered `col` values (sorted from least to greatest) such that
+ * no more than `percentage` of `col` values is less than or equal to that value.
*
- * If percentage is an array, each value must be between 0.0 and 1.0.
- * If it is a single floating point value, it must be between 0.0 and 1.0.
+ * If percentage is an array, each value must be between 0.0 and 1.0. If it is a single floating
+ * point value, it must be between 0.0 and 1.0.
*
- * The accuracy parameter is a positive numeric literal
- * which controls approximation accuracy at the cost of memory.
- * Higher value of accuracy yields better accuracy, 1.0/accuracy
- * is the relative error of the approximation.
+ * The accuracy parameter is a positive numeric literal which controls approximation accuracy
+ * at the cost of memory. A higher value of accuracy yields better accuracy; 1.0/accuracy is
+ * the relative error of the approximation.
*
* @group agg_funcs
* @since 3.1.0
@@ -981,17 +1010,16 @@ object functions {
Column.fn("percentile_approx", e, percentage, accuracy)
/**
- * Aggregate function: returns the approximate `percentile` of the numeric column `col` which
- * is the smallest value in the ordered `col` values (sorted from least to greatest) such that
- * no more than `percentage` of `col` values is less than the value or equal to that value.
+ * Aggregate function: returns the approximate `percentile` of the numeric column `col`, which
+ * is the smallest value in the ordered `col` values (sorted from least to greatest) such that
+ * no more than `percentage` of `col` values is less than or equal to that value.
*
- * If percentage is an array, each value must be between 0.0 and 1.0.
- * If it is a single floating point value, it must be between 0.0 and 1.0.
+ * If percentage is an array, each value must be between 0.0 and 1.0. If it is a single floating
+ * point value, it must be between 0.0 and 1.0.
*
- * The accuracy parameter is a positive numeric literal
- * which controls approximation accuracy at the cost of memory.
- * Higher value of accuracy yields better accuracy, 1.0/accuracy
- * is the relative error of the approximation.
+ * The accuracy parameter is a positive numeric literal which controls approximation accuracy
+ * at the cost of memory. A higher value of accuracy yields better accuracy; 1.0/accuracy is
+ * the relative error of the approximation.
*
* @group agg_funcs
* @since 3.5.0
@@ -1049,8 +1077,7 @@ object functions {
def stddev(columnName: String): Column = stddev(Column(columnName))
/**
- * Aggregate function: returns the sample standard deviation of
- * the expression in a group.
+ * Aggregate function: returns the sample standard deviation of the expression in a group.
*
* @group agg_funcs
* @since 1.6.0
@@ -1058,8 +1085,7 @@ object functions {
def stddev_samp(e: Column): Column = Column.fn("stddev_samp", e)
/**
- * Aggregate function: returns the sample standard deviation of
- * the expression in a group.
+ * Aggregate function: returns the sample standard deviation of the expression in a group.
*
* @group agg_funcs
* @since 1.6.0
@@ -1067,8 +1093,7 @@ object functions {
def stddev_samp(columnName: String): Column = stddev_samp(Column(columnName))
/**
- * Aggregate function: returns the population standard deviation of
- * the expression in a group.
+ * Aggregate function: returns the population standard deviation of the expression in a group.
*
* @group agg_funcs
* @since 1.6.0
@@ -1076,8 +1101,7 @@ object functions {
def stddev_pop(e: Column): Column = Column.fn("stddev_pop", e)
/**
- * Aggregate function: returns the population standard deviation of
- * the expression in a group.
+ * Aggregate function: returns the population standard deviation of the expression in a group.
*
* @group agg_funcs
* @since 1.6.0
@@ -1175,8 +1199,8 @@ object functions {
def var_pop(columnName: String): Column = var_pop(Column(columnName))
/**
- * Aggregate function: returns the average of the independent variable for non-null pairs
- * in a group, where `y` is the dependent variable and `x` is the independent variable.
+ * Aggregate function: returns the average of the independent variable for non-null pairs in a
+ * group, where `y` is the dependent variable and `x` is the independent variable.
*
* @group agg_funcs
* @since 3.5.0
@@ -1184,8 +1208,8 @@ object functions {
def regr_avgx(y: Column, x: Column): Column = Column.fn("regr_avgx", y, x)
/**
- * Aggregate function: returns the average of the independent variable for non-null pairs
- * in a group, where `y` is the dependent variable and `x` is the independent variable.
+ * Aggregate function: returns the average of the dependent variable for non-null pairs in a
+ * group, where `y` is the dependent variable and `x` is the independent variable.
*
* @group agg_funcs
* @since 3.5.0
@@ -1193,8 +1217,8 @@ object functions {
def regr_avgy(y: Column, x: Column): Column = Column.fn("regr_avgy", y, x)
/**
- * Aggregate function: returns the number of non-null number pairs
- * in a group, where `y` is the dependent variable and `x` is the independent variable.
+ * Aggregate function: returns the number of non-null number pairs in a group, where `y` is the
+ * dependent variable and `x` is the independent variable.
*
* @group agg_funcs
* @since 3.5.0
@@ -1202,9 +1226,9 @@ object functions {
def regr_count(y: Column, x: Column): Column = Column.fn("regr_count", y, x)
/**
- * Aggregate function: returns the intercept of the univariate linear regression line
- * for non-null pairs in a group, where `y` is the dependent variable and
- * `x` is the independent variable.
+ * Aggregate function: returns the intercept of the univariate linear regression line for
+ * non-null pairs in a group, where `y` is the dependent variable and `x` is the independent
+ * variable.
*
* @group agg_funcs
* @since 3.5.0
@@ -1212,8 +1236,8 @@ object functions {
def regr_intercept(y: Column, x: Column): Column = Column.fn("regr_intercept", y, x)
/**
- * Aggregate function: returns the coefficient of determination for non-null pairs
- * in a group, where `y` is the dependent variable and `x` is the independent variable.
+ * Aggregate function: returns the coefficient of determination for non-null pairs in a group,
+ * where `y` is the dependent variable and `x` is the independent variable.
*
* @group agg_funcs
* @since 3.5.0
@@ -1221,8 +1245,8 @@ object functions {
def regr_r2(y: Column, x: Column): Column = Column.fn("regr_r2", y, x)
/**
- * Aggregate function: returns the slope of the linear regression line for non-null pairs
- * in a group, where `y` is the dependent variable and `x` is the independent variable.
+ * Aggregate function: returns the slope of the linear regression line for non-null pairs in a
+ * group, where `y` is the dependent variable and `x` is the independent variable.
*
* @group agg_funcs
* @since 3.5.0
@@ -1230,8 +1254,8 @@ object functions {
def regr_slope(y: Column, x: Column): Column = Column.fn("regr_slope", y, x)
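The `regr_*` aggregates above fit a univariate linear regression per group; a minimal sketch with made-up points (SparkSession and implicits assumed):

{{{
import org.apache.spark.sql.functions._

// y is roughly 2 * x, so expect a slope near 2 and r2 close to 1.
val points = Seq((1.0, 2.1), (2.0, 3.9), (3.0, 6.2), (4.0, 7.8)).toDF("x", "y")

points.agg(
    regr_slope($"y", $"x").as("slope"),
    regr_intercept($"y", $"x").as("intercept"),
    regr_r2($"y", $"x").as("r2"),
    regr_count($"y", $"x").as("non_null_pairs"))
  .show()
}}}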
/**
- * Aggregate function: returns REGR_COUNT(y, x) * VAR_POP(x) for non-null pairs
- * in a group, where `y` is the dependent variable and `x` is the independent variable.
+ * Aggregate function: returns REGR_COUNT(y, x) * VAR_POP(x) for non-null pairs in a group,
+ * where `y` is the dependent variable and `x` is the independent variable.
*
* @group agg_funcs
* @since 3.5.0
@@ -1239,8 +1263,8 @@ object functions {
def regr_sxx(y: Column, x: Column): Column = Column.fn("regr_sxx", y, x)
/**
- * Aggregate function: returns REGR_COUNT(y, x) * COVAR_POP(y, x) for non-null pairs
- * in a group, where `y` is the dependent variable and `x` is the independent variable.
+ * Aggregate function: returns REGR_COUNT(y, x) * COVAR_POP(y, x) for non-null pairs in a group,
+ * where `y` is the dependent variable and `x` is the independent variable.
*
* @group agg_funcs
* @since 3.5.0
@@ -1248,8 +1272,8 @@ object functions {
def regr_sxy(y: Column, x: Column): Column = Column.fn("regr_sxy", y, x)
/**
- * Aggregate function: returns REGR_COUNT(y, x) * VAR_POP(y) for non-null pairs
- * in a group, where `y` is the dependent variable and `x` is the independent variable.
+ * Aggregate function: returns REGR_COUNT(y, x) * VAR_POP(y) for non-null pairs in a group,
+ * where `y` is the dependent variable and `x` is the independent variable.
*
* @group agg_funcs
* @since 3.5.0
@@ -1265,8 +1289,8 @@ object functions {
def any_value(e: Column): Column = Column.fn("any_value", e)
/**
- * Aggregate function: returns some value of `e` for a group of rows.
- * If `isIgnoreNull` is true, returns only non-null values.
+ * Aggregate function: returns some value of `e` for a group of rows. If `isIgnoreNull` is true,
+ * returns only non-null values.
*
* @group agg_funcs
* @since 3.5.0
@@ -1283,13 +1307,12 @@ object functions {
def count_if(e: Column): Column = Column.fn("count_if", e)
/**
- * Aggregate function: computes a histogram on numeric 'expr' using nb bins.
- * The return value is an array of (x,y) pairs representing the centers of the
- * histogram's bins. As the value of 'nb' is increased, the histogram approximation
- * gets finer-grained, but may yield artifacts around outliers. In practice, 20-40
- * histogram bins appear to work well, with more bins being required for skewed or
- * smaller datasets. Note that this function creates a histogram with non-uniform
- * bin widths. It offers no guarantees in terms of the mean-squared-error of the
+ * Aggregate function: computes a histogram on numeric 'expr' using nb bins. The return value is
+ * an array of (x,y) pairs representing the centers of the histogram's bins. As the value of
+ * 'nb' is increased, the histogram approximation gets finer-grained, but may yield artifacts
+ * around outliers. In practice, 20-40 histogram bins appear to work well, with more bins being
+ * required for skewed or smaller datasets. Note that this function creates a histogram with
+ * non-uniform bin widths. It offers no guarantees in terms of the mean-squared-error of the
* histogram, but in practice is comparable to the histograms produced by the R/S-Plus
* statistical computing packages. Note: the output type of the 'x' field in the return value is
* propagated from the input value consumed in the aggregate function.
@@ -1386,10 +1409,10 @@ object functions {
* Window function: returns the rank of rows within a window partition, without any gaps.
*
* The difference between rank and dense_rank is that denseRank leaves no gaps in ranking
- * sequence when there are ties. That is, if you were ranking a competition using dense_rank
- * and had three people tie for second place, you would say that all three were in second
- * place and that the next person came in third. Rank would give me sequential numbers, making
- * the person that came in third place (after the ties) would register as coming in fifth.
+ * sequence when there are ties. That is, if you were ranking a competition using dense_rank and
+ * had three people tie for second place, you would say that all three were in second place and
+ * that the next person came in third. Rank would give sequential numbers, so the person that
+ * came in third place (after the ties) would register as coming in fifth.
*
* This is equivalent to the DENSE_RANK function in SQL.
*
@@ -1399,9 +1422,9 @@ object functions {
def dense_rank(): Column = Column.fn("dense_rank")
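The tie-handling difference described above is easiest to see side by side; a minimal sketch (hypothetical scores, SparkSession and implicits assumed):

{{{
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val scores = Seq(("a", 100), ("b", 100), ("c", 90), ("d", 80)).toDF("name", "score")
val w = Window.orderBy($"score".desc)

scores.select(
    $"name", $"score",
    rank().over(w).as("rank"),              // 1, 1, 3, 4  (gap after the tie)
    dense_rank().over(w).as("dense_rank"))  // 1, 1, 2, 3  (no gap)
  .show()
}}}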
/**
- * Window function: returns the value that is `offset` rows before the current row, and
- * `null` if there is less than `offset` rows before the current row. For example,
- * an `offset` of one will return the previous row at any given point in the window partition.
+ * Window function: returns the value that is `offset` rows before the current row, and `null`
+ * if there are fewer than `offset` rows before the current row. For example, an `offset` of one
+ * will return the previous row at any given point in the window partition.
*
* This is equivalent to the LAG function in SQL.
*
@@ -1411,9 +1434,9 @@ object functions {
def lag(e: Column, offset: Int): Column = lag(e, offset, null)
/**
- * Window function: returns the value that is `offset` rows before the current row, and
- * `null` if there is less than `offset` rows before the current row. For example,
- * an `offset` of one will return the previous row at any given point in the window partition.
+ * Window function: returns the value that is `offset` rows before the current row, and `null`
+ * if there are fewer than `offset` rows before the current row. For example, an `offset` of one
+ * will return the previous row at any given point in the window partition.
*
* This is equivalent to the LAG function in SQL.
*
@@ -1424,8 +1447,8 @@ object functions {
/**
* Window function: returns the value that is `offset` rows before the current row, and
- * `defaultValue` if there is less than `offset` rows before the current row. For example,
- * an `offset` of one will return the previous row at any given point in the window partition.
+ * `defaultValue` if there are fewer than `offset` rows before the current row. For example, an
+ * `offset` of one will return the previous row at any given point in the window partition.
*
* This is equivalent to the LAG function in SQL.
*
@@ -1438,8 +1461,8 @@ object functions {
/**
* Window function: returns the value that is `offset` rows before the current row, and
- * `defaultValue` if there is less than `offset` rows before the current row. For example,
- * an `offset` of one will return the previous row at any given point in the window partition.
+ * `defaultValue` if there is less than `offset` rows before the current row. For example, an
+ * `offset` of one will return the previous row at any given point in the window partition.
*
* This is equivalent to the LAG function in SQL.
*
@@ -1453,9 +1476,9 @@ object functions {
/**
* Window function: returns the value that is `offset` rows before the current row, and
* `defaultValue` if there is less than `offset` rows before the current row. `ignoreNulls`
- * determines whether null values of row are included in or eliminated from the calculation.
- * For example, an `offset` of one will return the previous row at any given point in the
- * window partition.
+ * determines whether null values of row are included in or eliminated from the calculation. For
+ * example, an `offset` of one will return the previous row at any given point in the window
+ * partition.
*
* This is equivalent to the LAG function in SQL.
*
@@ -1466,9 +1489,9 @@ object functions {
Column.fn("lag", false, e, lit(offset), lit(defaultValue), lit(ignoreNulls))
/**
- * Window function: returns the value that is `offset` rows after the current row, and
- * `null` if there is less than `offset` rows after the current row. For example,
- * an `offset` of one will return the next row at any given point in the window partition.
+ * Window function: returns the value that is `offset` rows after the current row, and `null` if
+ * there are fewer than `offset` rows after the current row. For example, an `offset` of one will
+ * return the next row at any given point in the window partition.
*
* This is equivalent to the LEAD function in SQL.
*
@@ -1478,9 +1501,9 @@ object functions {
def lead(columnName: String, offset: Int): Column = { lead(columnName, offset, null) }
/**
- * Window function: returns the value that is `offset` rows after the current row, and
- * `null` if there is less than `offset` rows after the current row. For example,
- * an `offset` of one will return the next row at any given point in the window partition.
+ * Window function: returns the value that is `offset` rows after the current row, and `null` if
+ * there are fewer than `offset` rows after the current row. For example, an `offset` of one will
+ * return the next row at any given point in the window partition.
*
* This is equivalent to the LEAD function in SQL.
*
@@ -1491,8 +1514,8 @@ object functions {
/**
* Window function: returns the value that is `offset` rows after the current row, and
- * `defaultValue` if there is less than `offset` rows after the current row. For example,
- * an `offset` of one will return the next row at any given point in the window partition.
+ * `defaultValue` if there are fewer than `offset` rows after the current row. For example, an
+ * `offset` of one will return the next row at any given point in the window partition.
*
* This is equivalent to the LEAD function in SQL.
*
@@ -1505,8 +1528,8 @@ object functions {
/**
* Window function: returns the value that is `offset` rows after the current row, and
- * `defaultValue` if there is less than `offset` rows after the current row. For example,
- * an `offset` of one will return the next row at any given point in the window partition.
+ * `defaultValue` if there are fewer than `offset` rows after the current row. For example, an
+ * `offset` of one will return the next row at any given point in the window partition.
*
* This is equivalent to the LEAD function in SQL.
*
@@ -1520,9 +1543,9 @@ object functions {
/**
* Window function: returns the value that is `offset` rows after the current row, and
* `defaultValue` if there is less than `offset` rows after the current row. `ignoreNulls`
- * determines whether null values of row are included in or eliminated from the calculation.
- * The default value of `ignoreNulls` is false. For example, an `offset` of one will return
- * the next row at any given point in the window partition.
+ * determines whether null values of row are included in or eliminated from the calculation. The
+ * default value of `ignoreNulls` is false. For example, an `offset` of one will return the next
+ * row at any given point in the window partition.
*
* This is equivalent to the LEAD function in SQL.
*
@@ -1533,11 +1556,11 @@ object functions {
Column.fn("lead", false, e, lit(offset), lit(defaultValue), lit(ignoreNulls))
/**
- * Window function: returns the value that is the `offset`th row of the window frame
- * (counting from 1), and `null` if the size of window frame is less than `offset` rows.
+ * Window function: returns the value that is the `offset`th row of the window frame (counting
+ * from 1), and `null` if the size of the window frame is less than `offset` rows.
*
- * It will return the `offset`th non-null value it sees when ignoreNulls is set to true.
- * If all values are null, then null is returned.
+ * It will return the `offset`th non-null value it sees when ignoreNulls is set to true. If all
+ * values are null, then null is returned.
*
* This is equivalent to the nth_value function in SQL.
*
@@ -1548,8 +1571,8 @@ object functions {
Column.fn("nth_value", false, e, lit(offset), lit(ignoreNulls))
/**
- * Window function: returns the value that is the `offset`th row of the window frame
- * (counting from 1), and `null` if the size of window frame is less than `offset` rows.
+ * Window function: returns the value that is the `offset`th row of the window frame (counting
+ * from 1), and `null` if the size of the window frame is less than `offset` rows.
*
* This is equivalent to the nth_value function in SQL.
*
@@ -1560,8 +1583,8 @@ object functions {
/**
* Window function: returns the ntile group id (from 1 to `n` inclusive) in an ordered window
- * partition. For example, if `n` is 4, the first quarter of the rows will get value 1, the second
- * quarter will get 2, the third quarter will get 3, and the last quarter will get 4.
+ * partition. For example, if `n` is 4, the first quarter of the rows will get value 1, the
+ * second quarter will get 2, the third quarter will get 3, and the last quarter will get 4.
*
* This is equivalent to the NTILE function in SQL.
*
@@ -1571,7 +1594,8 @@ object functions {
def ntile(n: Int): Column = Column.fn("ntile", lit(n))
/**
- * Window function: returns the relative rank (i.e. percentile) of rows within a window partition.
+ * Window function: returns the relative rank (i.e. percentile) of rows within a window
+ * partition.
*
* This is computed by:
* {{{
@@ -1589,10 +1613,10 @@ object functions {
* Window function: returns the rank of rows within a window partition.
*
* The difference between rank and dense_rank is that dense_rank leaves no gaps in ranking
- * sequence when there are ties. That is, if you were ranking a competition using dense_rank
- * and had three people tie for second place, you would say that all three were in second
- * place and that the next person came in third. Rank would give me sequential numbers, making
- * the person that came in third place (after the ties) would register as coming in fifth.
+ * sequence when there are ties. That is, if you were ranking a competition using dense_rank and
+ * had three people tie for second place, you would say that all three were in second place and
+ * that the next person came in third. Rank would give sequential numbers, so the person that
+ * came in third place (after the ties) would register as coming in fifth.
*
* This is equivalent to the RANK function in SQL.
*
@@ -1630,13 +1654,13 @@ object functions {
*/
@scala.annotation.varargs
def array(colName: String, colNames: String*): Column = {
- array((colName +: colNames).map(col) : _*)
+ array((colName +: colNames).map(col): _*)
}
/**
- * Creates a new map column. The input columns must be grouped as key-value pairs, e.g.
- * (key1, value1, key2, value2, ...). The key columns must all have the same data type, and can't
- * be null. The value columns must all have the same data type.
+ * Creates a new map column. The input columns must be grouped as key-value pairs, e.g. (key1,
+ * value1, key2, value2, ...). The key columns must all have the same data type, and can't be
+ * null. The value columns must all have the same data type.
*
* @group map_funcs
* @since 2.0
@@ -1663,8 +1687,8 @@ object functions {
Column.fn("map_from_arrays", keys, values)
/**
- * Creates a map after splitting the text into key/value pairs using delimiters.
- * Both `pairDelim` and `keyValueDelim` are treated as regular expressions.
+ * Creates a map after splitting the text into key/value pairs using delimiters. Both
+ * `pairDelim` and `keyValueDelim` are treated as regular expressions.
*
* @group map_funcs
* @since 3.5.0
@@ -1673,8 +1697,8 @@ object functions {
Column.fn("str_to_map", text, pairDelim, keyValueDelim)
/**
- * Creates a map after splitting the text into key/value pairs using delimiters.
- * The `pairDelim` is treated as regular expressions.
+ * Creates a map after splitting the text into key/value pairs using delimiters. The `pairDelim`
+ * is treated as a regular expression.
*
* @group map_funcs
* @since 3.5.0
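For `str_to_map`, a one-line usage sketch; the raw string and delimiters are assumptions, and a SparkSession with implicits is expected to be in scope.

{{{
import org.apache.spark.sql.functions._

Seq("a:1,b:2").toDF("raw")
  .select(str_to_map($"raw", lit(","), lit(":")).as("kv"))  // Map(a -> 1, b -> 2)
  .show(false)
}}}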
@@ -1702,15 +1726,15 @@ object functions {
* @group normal_funcs
* @since 1.5.0
*/
- def broadcast[DS[U] <: api.Dataset[U, DS]](df: DS[_]): df.type = {
+ def broadcast[DS[U] <: api.Dataset[U]](df: DS[_]): df.type = {
df.hint("broadcast").asInstanceOf[df.type]
}
/**
* Returns the first column that is not null, or null if all inputs are null.
*
- * For example, `coalesce(a, b, c)` will return a if a is not null,
- * or b if a is null and b is not null, or c if both a and b are null but c is not null.
+ * For example, `coalesce(a, b, c)` will return a if a is not null, or b if a is null and b is
+ * not null, or c if both a and b are null but c is not null.
*
* @group conditional_funcs
* @since 1.3.0
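A minimal sketch of the `coalesce` behaviour described above, with made-up user data (SparkSession and implicits assumed):

{{{
import org.apache.spark.sql.functions._

val users = Seq(
    (Some("alice"), None: Option[String]),
    (None: Option[String], Some("bob@example.com")))
  .toDF("nickname", "email")

// Picks nickname if present, else email, else the literal fallback.
users.select(coalesce($"nickname", $"email", lit("unknown")).as("display_name")).show()
}}}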
@@ -1745,13 +1769,13 @@ object functions {
/**
* A column expression that generates monotonically increasing 64-bit integers.
*
- * The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive.
- * The current implementation puts the partition ID in the upper 31 bits, and the record number
- * within each partition in the lower 33 bits. The assumption is that the data frame has
- * less than 1 billion partitions, and each partition has less than 8 billion records.
+ * The generated ID is guaranteed to be monotonically increasing and unique, but not
+ * consecutive. The current implementation puts the partition ID in the upper 31 bits, and the
+ * record number within each partition in the lower 33 bits. The assumption is that the data
+ * frame has less than 1 billion partitions, and each partition has less than 8 billion records.
*
- * As an example, consider a `DataFrame` with two partitions, each with 3 records.
- * This expression would return the following IDs:
+ * As an example, consider a `DataFrame` with two partitions, each with 3 records. This
+ * expression would return the following IDs:
*
* {{{
* 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594.
@@ -1766,13 +1790,13 @@ object functions {
/**
* A column expression that generates monotonically increasing 64-bit integers.
*
- * The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive.
- * The current implementation puts the partition ID in the upper 31 bits, and the record number
- * within each partition in the lower 33 bits. The assumption is that the data frame has
- * less than 1 billion partitions, and each partition has less than 8 billion records.
+ * The generated ID is guaranteed to be monotonically increasing and unique, but not
+ * consecutive. The current implementation puts the partition ID in the upper 31 bits, and the
+ * record number within each partition in the lower 33 bits. The assumption is that the data
+ * frame has less than 1 billion partitions, and each partition has less than 8 billion records.
*
- * As an example, consider a `DataFrame` with two partitions, each with 3 records.
- * This expression would return the following IDs:
+ * As an example, consider a `DataFrame` with two partitions, each with 3 records. This
+ * expression would return the following IDs:
*
* {{{
* 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594.
@@ -1828,7 +1852,8 @@ object functions {
* Generate a random column with independent and identically distributed (i.i.d.) samples
* uniformly distributed in [0.0, 1.0).
*
- * @note The function is non-deterministic in general case.
+ * @note
+ * The function is non-deterministic in the general case.
*
* @group math_funcs
* @since 1.4.0
@@ -1839,7 +1864,8 @@ object functions {
* Generate a random column with independent and identically distributed (i.i.d.) samples
* uniformly distributed in [0.0, 1.0).
*
- * @note The function is non-deterministic in general case.
+ * @note
+ * The function is non-deterministic in the general case.
*
* @group math_funcs
* @since 1.4.0
@@ -1847,10 +1873,11 @@ object functions {
def rand(): Column = rand(SparkClassUtils.random.nextLong)
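To illustrate the seeded variants and the non-determinism caveat, a small sketch; the `spark` session and implicits are assumed to be in scope.

{{{
import org.apache.spark.sql.functions._

// With a fixed seed the sampled values are reproducible for the same plan and
// partitioning, but they can change if the data is repartitioned (the
// "non-deterministic in the general case" caveat above).
spark.range(5).select(
    $"id",
    rand(42L).as("uniform_0_1"),
    randn(42L).as("std_normal"))
  .show()
}}}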
/**
- * Generate a column with independent and identically distributed (i.i.d.) samples from
- * the standard normal distribution.
+ * Generate a column with independent and identically distributed (i.i.d.) samples from the
+ * standard normal distribution.
*
- * @note The function is non-deterministic in general case.
+ * @note
+ * The function is non-deterministic in the general case.
*
* @group math_funcs
* @since 1.4.0
@@ -1858,10 +1885,11 @@ object functions {
def randn(seed: Long): Column = Column.fn("randn", lit(seed))
/**
- * Generate a column with independent and identically distributed (i.i.d.) samples from
- * the standard normal distribution.
+ * Generate a column with independent and identically distributed (i.i.d.) samples from the
+ * standard normal distribution.
*
- * @note The function is non-deterministic in general case.
+ * @note
+ * The function is non-deterministic in the general case.
*
* @group math_funcs
* @since 1.4.0
@@ -1871,7 +1899,8 @@ object functions {
/**
* Partition ID.
*
- * @note This is non-deterministic because it depends on data partitioning and task scheduling.
+ * @note
+ * This is non-deterministic because it depends on data partitioning and task scheduling.
*
* @group misc_funcs
* @since 1.6.0
@@ -1921,8 +1950,7 @@ object functions {
def try_divide(left: Column, right: Column): Column = Column.fn("try_divide", left, right)
/**
- * Returns the remainder of `dividend``/``divisor`. Its result is
- * always null if `divisor` is 0.
+ * Returns the remainder of `dividend``/``divisor`. Its result is always null if `divisor` is 0.
*
* @group math_funcs
* @since 4.0.0
@@ -1956,11 +1984,10 @@ object functions {
def try_sum(e: Column): Column = Column.fn("try_sum", e)
/**
- * Creates a new struct column.
- * If the input column is a column in a `DataFrame`, or a derived column expression
- * that is named (i.e. aliased), its name would be retained as the StructField's name,
- * otherwise, the newly generated StructField's name would be auto generated as
- * `col` with a suffix `index + 1`, i.e. col1, col2, col3, ...
+ * Creates a new struct column. If the input column is a column in a `DataFrame`, or a derived
+ * column expression that is named (i.e. aliased), its name is retained as the StructField's
+ * name; otherwise, the newly generated StructField's name is auto-generated as `col` with a
+ * suffix `index + 1`, i.e. col1, col2, col3, ...
*
* @group struct_funcs
* @since 1.4.0
@@ -1976,12 +2003,12 @@ object functions {
*/
@scala.annotation.varargs
def struct(colName: String, colNames: String*): Column = {
- struct((colName +: colNames).map(col) : _*)
+ struct((colName +: colNames).map(col): _*)
}
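The field-naming rule above (named inputs keep their name, other expressions become col1, col2, ...) can be checked with a quick sketch; the data is hypothetical and the schema comment shows the expected outcome, not captured output.

{{{
import org.apache.spark.sql.functions._

val people = Seq(("Ann", 41)).toDF("name", "age")

people
  .select(struct($"name", ($"age" + 1).as("age_next_year"), lit(1)).as("s"))
  .printSchema()
// expected: s: struct<name: string, age_next_year: int, col3: int>
//           named inputs keep their names; the unnamed literal becomes col3
}}}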
/**
- * Evaluates a list of conditions and returns one of multiple possible result expressions.
- * If otherwise is not defined at the end, null is returned for unmatched conditions.
+ * Evaluates a list of conditions and returns one of multiple possible result expressions. If
+ * otherwise is not defined at the end, null is returned for unmatched conditions.
*
* {{{
* // Example: encoding gender string column into integer.
@@ -2030,9 +2057,8 @@ object functions {
def bit_count(e: Column): Column = Column.fn("bit_count", e)
/**
- * Returns the value of the bit (0 or 1) at the specified position.
- * The positions are numbered from right to left, starting at zero.
- * The position argument cannot be negative.
+ * Returns the value of the bit (0 or 1) at the specified position. The positions are numbered
+ * from right to left, starting at zero. The position argument cannot be negative.
*
* @group bitwise_funcs
* @since 3.5.0
@@ -2040,9 +2066,8 @@ object functions {
def bit_get(e: Column, pos: Column): Column = Column.fn("bit_get", e, pos)
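A tiny sketch of `bit_count` and `bit_get` on the value 10 (binary 1010); positions are counted from the right, starting at zero. Implicits and a SparkSession are assumed.

{{{
import org.apache.spark.sql.functions._

Seq(10).toDF("n").select(
    bit_count($"n").as("set_bits"),     // 2
    bit_get($"n", lit(0)).as("bit_0"),  // 0
    bit_get($"n", lit(1)).as("bit_1"))  // 1
  .show()
}}}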
/**
- * Returns the value of the bit (0 or 1) at the specified position.
- * The positions are numbered from right to left, starting at zero.
- * The position argument cannot be negative.
+ * Returns the value of the bit (0 or 1) at the specified position. The positions are numbered
+ * from right to left, starting at zero. The position argument cannot be negative.
*
* @group bitwise_funcs
* @since 3.5.0
@@ -2074,7 +2099,8 @@ object functions {
def abs(e: Column): Column = Column.fn("abs", e)
/**
- * @return inverse cosine of `e` in radians, as if computed by `java.lang.Math.acos`
+ * @return
+ * inverse cosine of `e` in radians, as if computed by `java.lang.Math.acos`
*
* @group math_funcs
* @since 1.4.0
@@ -2082,7 +2108,8 @@ object functions {
def acos(e: Column): Column = Column.fn("acos", e)
/**
- * @return inverse cosine of `columnName`, as if computed by `java.lang.Math.acos`
+ * @return
+ * inverse cosine of `columnName`, as if computed by `java.lang.Math.acos`
*
* @group math_funcs
* @since 1.4.0
@@ -2090,7 +2117,8 @@ object functions {
def acos(columnName: String): Column = acos(Column(columnName))
/**
- * @return inverse hyperbolic cosine of `e`
+ * @return
+ * inverse hyperbolic cosine of `e`
*
* @group math_funcs
* @since 3.1.0
@@ -2098,7 +2126,8 @@ object functions {
def acosh(e: Column): Column = Column.fn("acosh", e)
/**
- * @return inverse hyperbolic cosine of `columnName`
+ * @return
+ * inverse hyperbolic cosine of `columnName`
*
* @group math_funcs
* @since 3.1.0
@@ -2106,7 +2135,8 @@ object functions {
def acosh(columnName: String): Column = acosh(Column(columnName))
/**
- * @return inverse sine of `e` in radians, as if computed by `java.lang.Math.asin`
+ * @return
+ * inverse sine of `e` in radians, as if computed by `java.lang.Math.asin`
*
* @group math_funcs
* @since 1.4.0
@@ -2114,7 +2144,8 @@ object functions {
def asin(e: Column): Column = Column.fn("asin", e)
/**
- * @return inverse sine of `columnName`, as if computed by `java.lang.Math.asin`
+ * @return
+ * inverse sine of `columnName`, as if computed by `java.lang.Math.asin`
*
* @group math_funcs
* @since 1.4.0
@@ -2122,7 +2153,8 @@ object functions {
def asin(columnName: String): Column = asin(Column(columnName))
/**
- * @return inverse hyperbolic sine of `e`
+ * @return
+ * inverse hyperbolic sine of `e`
*
* @group math_funcs
* @since 3.1.0
@@ -2130,7 +2162,8 @@ object functions {
def asinh(e: Column): Column = Column.fn("asinh", e)
/**
- * @return inverse hyperbolic sine of `columnName`
+ * @return
+ * inverse hyperbolic sine of `columnName`
*
* @group math_funcs
* @since 3.1.0
@@ -2138,7 +2171,8 @@ object functions {
def asinh(columnName: String): Column = asinh(Column(columnName))
/**
- * @return inverse tangent of `e` as if computed by `java.lang.Math.atan`
+ * @return
+ * inverse tangent of `e` as if computed by `java.lang.Math.atan`
*
* @group math_funcs
* @since 1.4.0
@@ -2146,7 +2180,8 @@ object functions {
def atan(e: Column): Column = Column.fn("atan", e)
/**
- * @return inverse tangent of `columnName`, as if computed by `java.lang.Math.atan`
+ * @return
+ * inverse tangent of `columnName`, as if computed by `java.lang.Math.atan`
*
* @group math_funcs
* @since 1.4.0
@@ -2154,13 +2189,14 @@ object functions {
def atan(columnName: String): Column = atan(Column(columnName))
/**
- * @param y coordinate on y-axis
- * @param x coordinate on x-axis
- * @return the theta component of the point
- * (r, theta)
- * in polar coordinates that corresponds to the point
- * (x, y) in Cartesian coordinates,
- * as if computed by `java.lang.Math.atan2`
+ * @param y
+ * coordinate on y-axis
+ * @param x
+ * coordinate on x-axis
+ * @return
+ * the theta component of the point (r, theta) in polar coordinates that
+ * corresponds to the point (x, y) in Cartesian coordinates, as if computed by
+ * `java.lang.Math.atan2`
*
* @group math_funcs
* @since 1.4.0
@@ -2168,13 +2204,14 @@ object functions {
def atan2(y: Column, x: Column): Column = Column.fn("atan2", y, x)
/**
- * @param y coordinate on y-axis
- * @param xName coordinate on x-axis
- * @return the theta component of the point
- * (r, theta)
- * in polar coordinates that corresponds to the point
- * (x, y) in Cartesian coordinates,
- * as if computed by `java.lang.Math.atan2`
+ * @param y
+ * coordinate on y-axis
+ * @param xName
+ * coordinate on x-axis
+ * @return
+ * the theta component of the point (r, theta) in polar coordinates that
+ * corresponds to the point (x, y) in Cartesian coordinates, as if computed by
+ * `java.lang.Math.atan2`
*
* @group math_funcs
* @since 1.4.0
@@ -2182,13 +2219,14 @@ object functions {
def atan2(y: Column, xName: String): Column = atan2(y, Column(xName))
/**
- * @param yName coordinate on y-axis
- * @param x coordinate on x-axis
- * @return the theta component of the point
- * (r, theta)
- * in polar coordinates that corresponds to the point
- * (x, y) in Cartesian coordinates,
- * as if computed by `java.lang.Math.atan2`
+ * @param yName
+ * coordinate on y-axis
+ * @param x
+ * coordinate on x-axis
+ * @return
+ * the theta component of the point (r, theta) in polar coordinates that
+ * corresponds to the point (x, y) in Cartesian coordinates, as if computed by
+ * `java.lang.Math.atan2`
*
* @group math_funcs
* @since 1.4.0
@@ -2196,13 +2234,14 @@ object functions {
def atan2(yName: String, x: Column): Column = atan2(Column(yName), x)
/**
- * @param yName coordinate on y-axis
- * @param xName coordinate on x-axis
- * @return the theta component of the point
- * (r, theta)
- * in polar coordinates that corresponds to the point
- * (x, y) in Cartesian coordinates,
- * as if computed by `java.lang.Math.atan2`
+ * @param yName
+ * coordinate on y-axis
+ * @param xName
+ * coordinate on x-axis
+ * @return
+ * the theta component of the point (r, theta) in polar coordinates that
+ * corresponds to the point (x, y) in Cartesian coordinates, as if computed by
+ * `java.lang.Math.atan2`
*
* @group math_funcs
* @since 1.4.0
@@ -2211,13 +2250,14 @@ object functions {
atan2(Column(yName), Column(xName))
/**
- * @param y coordinate on y-axis
- * @param xValue coordinate on x-axis
- * @return the theta component of the point
- * (r, theta)
- * in polar coordinates that corresponds to the point
- * (x, y) in Cartesian coordinates,
- * as if computed by `java.lang.Math.atan2`
+ * @param y
+ * coordinate on y-axis
+ * @param xValue
+ * coordinate on x-axis
+ * @return
+ * the theta component of the point (r, theta) in polar coordinates that
+ * corresponds to the point (x, y) in Cartesian coordinates, as if computed by
+ * `java.lang.Math.atan2`
*
* @group math_funcs
* @since 1.4.0
@@ -2225,13 +2265,14 @@ object functions {
def atan2(y: Column, xValue: Double): Column = atan2(y, lit(xValue))
/**
- * @param yName coordinate on y-axis
- * @param xValue coordinate on x-axis
- * @return the theta component of the point
- * (r, theta)
- * in polar coordinates that corresponds to the point
- * (x, y) in Cartesian coordinates,
- * as if computed by `java.lang.Math.atan2`
+ * @param yName
+ * coordinate on y-axis
+ * @param xValue
+ * coordinate on x-axis
+ * @return
+ * the theta component of the point (r, theta) in polar coordinates that
+ * corresponds to the point (x, y) in Cartesian coordinates, as if computed by
+ * `java.lang.Math.atan2`
*
* @group math_funcs
* @since 1.4.0
@@ -2239,13 +2280,14 @@ object functions {
def atan2(yName: String, xValue: Double): Column = atan2(Column(yName), xValue)
/**
- * @param yValue coordinate on y-axis
- * @param x coordinate on x-axis
- * @return the theta component of the point
- * (r, theta)
- * in polar coordinates that corresponds to the point
- * (x, y) in Cartesian coordinates,
- * as if computed by `java.lang.Math.atan2`
+ * @param yValue
+ * coordinate on y-axis
+ * @param x
+ * coordinate on x-axis
+ * @return
+ * the theta component of the point (r, theta) in polar coordinates that
+ * corresponds to the point (x, y) in Cartesian coordinates, as if computed by
+ * `java.lang.Math.atan2`
*
* @group math_funcs
* @since 1.4.0
@@ -2253,13 +2295,14 @@ object functions {
def atan2(yValue: Double, x: Column): Column = atan2(lit(yValue), x)
/**
- * @param yValue coordinate on y-axis
- * @param xName coordinate on x-axis
- * @return the theta component of the point
- * (r, theta)
- * in polar coordinates that corresponds to the point
- * (x, y) in Cartesian coordinates,
- * as if computed by `java.lang.Math.atan2`
+ * @param yValue
+ * coordinate on y-axis
+ * @param xName
+ * coordinate on x-axis
+ * @return
+ * the theta component of the point (r, theta) in polar coordinates that
+ * corresponds to the point (x, y) in Cartesian coordinates, as if computed by
+ * `java.lang.Math.atan2`
*
* @group math_funcs
* @since 1.4.0
@@ -2267,7 +2310,8 @@ object functions {
def atan2(yValue: Double, xName: String): Column = atan2(yValue, Column(xName))
/**
- * @return inverse hyperbolic tangent of `e`
+ * @return
+ * inverse hyperbolic tangent of `e`
*
* @group math_funcs
* @since 3.1.0
@@ -2275,7 +2319,8 @@ object functions {
def atanh(e: Column): Column = Column.fn("atanh", e)
/**
- * @return inverse hyperbolic tangent of `columnName`
+ * @return
+ * inverse hyperbolic tangent of `columnName`
*
* @group math_funcs
* @since 3.1.0
@@ -2366,8 +2411,10 @@ object functions {
Column.fn("conv", num, lit(fromBase), lit(toBase))
/**
- * @param e angle in radians
- * @return cosine of the angle, as if computed by `java.lang.Math.cos`
+ * @param e
+ * angle in radians
+ * @return
+ * cosine of the angle, as if computed by `java.lang.Math.cos`
*
* @group math_funcs
* @since 1.4.0
@@ -2375,8 +2422,10 @@ object functions {
def cos(e: Column): Column = Column.fn("cos", e)
/**
- * @param columnName angle in radians
- * @return cosine of the angle, as if computed by `java.lang.Math.cos`
+ * @param columnName
+ * angle in radians
+ * @return
+ * cosine of the angle, as if computed by `java.lang.Math.cos`
*
* @group math_funcs
* @since 1.4.0
@@ -2384,8 +2433,10 @@ object functions {
def cos(columnName: String): Column = cos(Column(columnName))
/**
- * @param e hyperbolic angle
- * @return hyperbolic cosine of the angle, as if computed by `java.lang.Math.cosh`
+ * @param e
+ * hyperbolic angle
+ * @return
+ * hyperbolic cosine of the angle, as if computed by `java.lang.Math.cosh`
*
* @group math_funcs
* @since 1.4.0
@@ -2393,8 +2444,10 @@ object functions {
def cosh(e: Column): Column = Column.fn("cosh", e)
/**
- * @param columnName hyperbolic angle
- * @return hyperbolic cosine of the angle, as if computed by `java.lang.Math.cosh`
+ * @param columnName
+ * hyperbolic angle
+ * @return
+ * hyperbolic cosine of the angle, as if computed by `java.lang.Math.cosh`
*
* @group math_funcs
* @since 1.4.0
@@ -2402,8 +2455,10 @@ object functions {
def cosh(columnName: String): Column = cosh(Column(columnName))
/**
- * @param e angle in radians
- * @return cotangent of the angle
+ * @param e
+ * angle in radians
+ * @return
+ * cotangent of the angle
*
* @group math_funcs
* @since 3.3.0
@@ -2411,8 +2466,10 @@ object functions {
def cot(e: Column): Column = Column.fn("cot", e)
/**
- * @param e angle in radians
- * @return cosecant of the angle
+ * @param e
+ * angle in radians
+ * @return
+ * cosecant of the angle
*
* @group math_funcs
* @since 3.3.0
@@ -2492,8 +2549,8 @@ object functions {
def floor(columnName: String): Column = floor(Column(columnName))
/**
- * Returns the greatest value of the list of values, skipping null values.
- * This function takes at least 2 parameters. It will return null iff all parameters are null.
+ * Returns the greatest value of the list of values, skipping null values. This function takes
+ * at least 2 parameters. It will return null iff all parameters are null.
*
* @group math_funcs
* @since 1.5.0
@@ -2502,8 +2559,8 @@ object functions {
def greatest(exprs: Column*): Column = Column.fn("greatest", exprs: _*)
/**
- * Returns the greatest value of the list of column names, skipping null values.
- * This function takes at least 2 parameters. It will return null iff all parameters are null.
+ * Returns the greatest value of the list of column names, skipping null values. This function
+ * takes at least 2 parameters. It will return null iff all parameters are null.
*
* @group math_funcs
* @since 1.5.0
@@ -2522,8 +2579,8 @@ object functions {
def hex(column: Column): Column = Column.fn("hex", column)
/**
- * Inverse of hex. Interprets each pair of characters as a hexadecimal number
- * and converts to the byte representation of number.
+ * Inverse of hex. Interprets each pair of characters as a hexadecimal number and converts to
+ * the byte representation of the number.
*
* @group math_funcs
* @since 1.5.0
@@ -2596,8 +2653,8 @@ object functions {
def hypot(l: Double, rightName: String): Column = hypot(l, Column(rightName))
/**
- * Returns the least value of the list of values, skipping null values.
- * This function takes at least 2 parameters. It will return null iff all parameters are null.
+ * Returns the least value of the list of values, skipping null values. This function takes at
+ * least 2 parameters. It will return null iff all parameters are null.
*
* @group math_funcs
* @since 1.5.0
@@ -2606,8 +2663,8 @@ object functions {
def least(exprs: Column*): Column = Column.fn("least", exprs: _*)
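A minimal sketch of the null-skipping behaviour of `greatest` and `least` (hypothetical columns; SparkSession and implicits assumed):

{{{
import org.apache.spark.sql.functions._

val t = Seq((Some(1), None: Option[Int], Some(7))).toDF("a", "b", "c")

// Nulls are skipped: greatest = 7, least = 1; only all-null inputs yield null.
t.select(greatest($"a", $"b", $"c").as("g"), least($"a", $"b", $"c").as("l")).show()
}}}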
/**
- * Returns the least value of the list of column names, skipping null values.
- * This function takes at least 2 parameters. It will return null iff all parameters are null.
+ * Returns the least value of the list of column names, skipping null values. This function
+ * takes at least 2 parameters. It will return null iff all parameters are null.
*
* @group math_funcs
* @since 1.5.0
@@ -2810,8 +2867,8 @@ object functions {
def pmod(dividend: Column, divisor: Column): Column = Column.fn("pmod", dividend, divisor)
/**
- * Returns the double value that is closest in value to the argument and
- * is equal to a mathematical integer.
+ * Returns the double value that is closest in value to the argument and is equal to a
+ * mathematical integer.
*
* @group math_funcs
* @since 1.4.0
@@ -2819,8 +2876,8 @@ object functions {
def rint(e: Column): Column = Column.fn("rint", e)
/**
- * Returns the double value that is closest in value to the argument and
- * is equal to a mathematical integer.
+ * Returns the double value that is closest in value to the argument and is equal to a
+ * mathematical integer.
*
* @group math_funcs
* @since 1.4.0
@@ -2836,8 +2893,8 @@ object functions {
def round(e: Column): Column = round(e, 0)
/**
- * Round the value of `e` to `scale` decimal places with HALF_UP round mode
- * if `scale` is greater than or equal to 0 or at integral part when `scale` is less than 0.
+ * Round the value of `e` to `scale` decimal places with HALF_UP round mode if `scale` is
+ * greater than or equal to 0, or at the integral part when `scale` is less than 0.
*
* @group math_funcs
* @since 1.5.0
@@ -2845,8 +2902,8 @@ object functions {
def round(e: Column, scale: Int): Column = Column.fn("round", e, lit(scale))
/**
- * Round the value of `e` to `scale` decimal places with HALF_UP round mode
- * if `scale` is greater than or equal to 0 or at integral part when `scale` is less than 0.
+ * Round the value of `e` to `scale` decimal places with HALF_UP round mode if `scale` is
+ * greater than or equal to 0, or at the integral part when `scale` is less than 0.
*
* @group math_funcs
* @since 4.0.0
@@ -2862,8 +2919,8 @@ object functions {
def bround(e: Column): Column = bround(e, 0)
/**
- * Round the value of `e` to `scale` decimal places with HALF_EVEN round mode
- * if `scale` is greater than or equal to 0 or at integral part when `scale` is less than 0.
+ * Round the value of `e` to `scale` decimal places with HALF_EVEN round mode if `scale` is
+ * greater than or equal to 0, or at the integral part when `scale` is less than 0.
*
* @group math_funcs
* @since 2.0.0
@@ -2871,8 +2928,8 @@ object functions {
def bround(e: Column, scale: Int): Column = Column.fn("bround", e, lit(scale))
/**
- * Round the value of `e` to `scale` decimal places with HALF_EVEN round mode
- * if `scale` is greater than or equal to 0 or at integral part when `scale` is less than 0.
+ * Round the value of `e` to `scale` decimal places with HALF_EVEN round mode if `scale` is
+ * greater than or equal to 0, or at the integral part when `scale` is less than 0.
*
* @group math_funcs
* @since 4.0.0
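To make the HALF_UP vs HALF_EVEN distinction in the round/bround doc blocks above concrete, a small illustration (not part of the patch; assumes a SparkSession `spark` is in scope):

  import org.apache.spark.sql.functions.{round, bround, lit}

  spark.range(1).select(
    round(lit(2.5), 0),   // 3.0 -- HALF_UP rounds the tie away from zero
    bround(lit(2.5), 0)   // 2.0 -- HALF_EVEN rounds the tie to the nearest even digit
  ).show()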
@@ -2880,8 +2937,10 @@ object functions {
def bround(e: Column, scale: Column): Column = Column.fn("bround", e, scale)
/**
- * @param e angle in radians
- * @return secant of the angle
+ * @param e
+ * angle in radians
+ * @return
+ * secant of the angle
*
* @group math_funcs
* @since 3.3.0
@@ -2889,8 +2948,8 @@ object functions {
def sec(e: Column): Column = Column.fn("sec", e)
/**
- * Shift the given value numBits left. If the given value is a long value, this function
- * will return a long value else it will return an integer value.
+ * Shift the given value numBits left. If the given value is a long value, this function will
+ * return a long value else it will return an integer value.
*
* @group bitwise_funcs
* @since 1.5.0
@@ -2899,8 +2958,8 @@ object functions {
def shiftLeft(e: Column, numBits: Int): Column = shiftleft(e, numBits)
/**
- * Shift the given value numBits left. If the given value is a long value, this function
- * will return a long value else it will return an integer value.
+ * Shift the given value numBits left. If the given value is a long value, this function will
+ * return a long value else it will return an integer value.
*
* @group bitwise_funcs
* @since 3.2.0
@@ -2927,8 +2986,8 @@ object functions {
def shiftright(e: Column, numBits: Int): Column = Column.fn("shiftright", e, lit(numBits))
/**
- * Unsigned shift the given value numBits right. If the given value is a long value,
- * it will return a long value else it will return an integer value.
+ * Unsigned shift the given value numBits right. If the given value is a long value, it will
+ * return a long value else it will return an integer value.
*
* @group bitwise_funcs
* @since 1.5.0
@@ -2937,8 +2996,8 @@ object functions {
def shiftRightUnsigned(e: Column, numBits: Int): Column = shiftrightunsigned(e, numBits)
/**
- * Unsigned shift the given value numBits right. If the given value is a long value,
- * it will return a long value else it will return an integer value.
+ * Unsigned shift the given value numBits right. If the given value is a long value, it will
+ * return a long value else it will return an integer value.
*
* @group bitwise_funcs
* @since 3.2.0
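A quick sketch of the shift functions documented above, contrasting the signed and unsigned right shifts (illustrative only; assumes `spark` is in scope):

  import org.apache.spark.sql.functions.{shiftleft, shiftright, shiftrightunsigned, lit}

  spark.range(1).select(
    shiftleft(lit(2), 3),            // 16 -- integer input yields an integer result
    shiftright(lit(-16), 2),         // -4 -- arithmetic shift preserves the sign
    shiftrightunsigned(lit(-16), 2)  // 1073741820 -- zeros are shifted in from the left
  ).show()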
@@ -2971,8 +3030,10 @@ object functions {
def signum(columnName: String): Column = signum(Column(columnName))
/**
- * @param e angle in radians
- * @return sine of the angle, as if computed by `java.lang.Math.sin`
+ * @param e
+ * angle in radians
+ * @return
+ * sine of the angle, as if computed by `java.lang.Math.sin`
*
* @group math_funcs
* @since 1.4.0
@@ -2980,8 +3041,10 @@ object functions {
def sin(e: Column): Column = Column.fn("sin", e)
/**
- * @param columnName angle in radians
- * @return sine of the angle, as if computed by `java.lang.Math.sin`
+ * @param columnName
+ * angle in radians
+ * @return
+ * sine of the angle, as if computed by `java.lang.Math.sin`
*
* @group math_funcs
* @since 1.4.0
@@ -2989,8 +3052,10 @@ object functions {
def sin(columnName: String): Column = sin(Column(columnName))
/**
- * @param e hyperbolic angle
- * @return hyperbolic sine of the given value, as if computed by `java.lang.Math.sinh`
+ * @param e
+ * hyperbolic angle
+ * @return
+ * hyperbolic sine of the given value, as if computed by `java.lang.Math.sinh`
*
* @group math_funcs
* @since 1.4.0
@@ -2998,8 +3063,10 @@ object functions {
def sinh(e: Column): Column = Column.fn("sinh", e)
/**
- * @param columnName hyperbolic angle
- * @return hyperbolic sine of the given value, as if computed by `java.lang.Math.sinh`
+ * @param columnName
+ * hyperbolic angle
+ * @return
+ * hyperbolic sine of the given value, as if computed by `java.lang.Math.sinh`
*
* @group math_funcs
* @since 1.4.0
@@ -3007,8 +3074,10 @@ object functions {
def sinh(columnName: String): Column = sinh(Column(columnName))
/**
- * @param e angle in radians
- * @return tangent of the given value, as if computed by `java.lang.Math.tan`
+ * @param e
+ * angle in radians
+ * @return
+ * tangent of the given value, as if computed by `java.lang.Math.tan`
*
* @group math_funcs
* @since 1.4.0
@@ -3016,8 +3085,10 @@ object functions {
def tan(e: Column): Column = Column.fn("tan", e)
/**
- * @param columnName angle in radians
- * @return tangent of the given value, as if computed by `java.lang.Math.tan`
+ * @param columnName
+ * angle in radians
+ * @return
+ * tangent of the given value, as if computed by `java.lang.Math.tan`
*
* @group math_funcs
* @since 1.4.0
@@ -3025,8 +3096,10 @@ object functions {
def tan(columnName: String): Column = tan(Column(columnName))
/**
- * @param e hyperbolic angle
- * @return hyperbolic tangent of the given value, as if computed by `java.lang.Math.tanh`
+ * @param e
+ * hyperbolic angle
+ * @return
+ * hyperbolic tangent of the given value, as if computed by `java.lang.Math.tanh`
*
* @group math_funcs
* @since 1.4.0
@@ -3034,8 +3107,10 @@ object functions {
def tanh(e: Column): Column = Column.fn("tanh", e)
/**
- * @param columnName hyperbolic angle
- * @return hyperbolic tangent of the given value, as if computed by `java.lang.Math.tanh`
+ * @param columnName
+ * hyperbolic angle
+ * @return
+ * hyperbolic tangent of the given value, as if computed by `java.lang.Math.tanh`
*
* @group math_funcs
* @since 1.4.0
@@ -3057,10 +3132,13 @@ object functions {
def toDegrees(columnName: String): Column = degrees(Column(columnName))
/**
- * Converts an angle measured in radians to an approximately equivalent angle measured in degrees.
+ * Converts an angle measured in radians to an approximately equivalent angle measured in
+ * degrees.
*
- * @param e angle in radians
- * @return angle in degrees, as if computed by `java.lang.Math.toDegrees`
+ * @param e
+ * angle in radians
+ * @return
+ * angle in degrees, as if computed by `java.lang.Math.toDegrees`
*
* @group math_funcs
* @since 2.1.0
@@ -3068,10 +3146,13 @@ object functions {
def degrees(e: Column): Column = Column.fn("degrees", e)
/**
- * Converts an angle measured in radians to an approximately equivalent angle measured in degrees.
+ * Converts an angle measured in radians to an approximately equivalent angle measured in
+ * degrees.
*
- * @param columnName angle in radians
- * @return angle in degrees, as if computed by `java.lang.Math.toDegrees`
+ * @param columnName
+ * angle in radians
+ * @return
+ * angle in degrees, as if computed by `java.lang.Math.toDegrees`
*
* @group math_funcs
* @since 2.1.0
@@ -3093,10 +3174,13 @@ object functions {
def toRadians(columnName: String): Column = radians(Column(columnName))
/**
- * Converts an angle measured in degrees to an approximately equivalent angle measured in radians.
+ * Converts an angle measured in degrees to an approximately equivalent angle measured in
+ * radians.
*
- * @param e angle in degrees
- * @return angle in radians, as if computed by `java.lang.Math.toRadians`
+ * @param e
+ * angle in degrees
+ * @return
+ * angle in radians, as if computed by `java.lang.Math.toRadians`
*
* @group math_funcs
* @since 2.1.0
@@ -3104,10 +3188,13 @@ object functions {
def radians(e: Column): Column = Column.fn("radians", e)
/**
- * Converts an angle measured in degrees to an approximately equivalent angle measured in radians.
+ * Converts an angle measured in degrees to an approximately equivalent angle measured in
+ * radians.
*
- * @param columnName angle in degrees
- * @return angle in radians, as if computed by `java.lang.Math.toRadians`
+ * @param columnName
+ * angle in degrees
+ * @return
+ * angle in radians, as if computed by `java.lang.Math.toRadians`
*
* @group math_funcs
* @since 2.1.0
@@ -3115,15 +3202,20 @@ object functions {
def radians(columnName: String): Column = radians(Column(columnName))
/**
- * Returns the bucket number into which the value of this expression would fall
- * after being evaluated. Note that input arguments must follow conditions listed below;
- * otherwise, the method will return null.
+ * Returns the bucket number into which the value of this expression would fall after being
+ * evaluated. Note that input arguments must follow conditions listed below; otherwise, the
+ * method will return null.
*
- * @param v value to compute a bucket number in the histogram
- * @param min minimum value of the histogram
- * @param max maximum value of the histogram
- * @param numBucket the number of buckets
- * @return the bucket number into which the value would fall after being evaluated
+ * @param v
+ * value to compute a bucket number in the histogram
+ * @param min
+ * minimum value of the histogram
+ * @param max
+ * maximum value of the histogram
+ * @param numBucket
+ * the number of buckets
+ * @return
+ * the bucket number into which the value would fall after being evaluated
* @group math_funcs
* @since 3.5.0
*/
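For the bucket-number semantics above, a minimal example (not part of the patch; assumes `spark` is in scope):

  import org.apache.spark.sql.functions.{width_bucket, lit}

  // Five equal-width buckets over [0.2, 10.6); the value 5.3 lands in bucket 3.
  spark.range(1).select(width_bucket(lit(5.3), lit(0.2), lit(10.6), lit(5))).show()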
@@ -3167,8 +3259,8 @@ object functions {
def current_user(): Column = Column.fn("current_user")
/**
- * Calculates the MD5 digest of a binary column and returns the value
- * as a 32 character hex string.
+ * Calculates the MD5 digest of a binary column and returns the value as a 32 character hex
+ * string.
*
* @group hash_funcs
* @since 1.5.0
@@ -3176,8 +3268,8 @@ object functions {
def md5(e: Column): Column = Column.fn("md5", e)
/**
- * Calculates the SHA-1 digest of a binary column and returns the value
- * as a 40 character hex string.
+ * Calculates the SHA-1 digest of a binary column and returns the value as a 40 character hex
+ * string.
*
* @group hash_funcs
* @since 1.5.0
@@ -3185,11 +3277,13 @@ object functions {
def sha1(e: Column): Column = Column.fn("sha1", e)
/**
- * Calculates the SHA-2 family of hash functions of a binary column and
- * returns the value as a hex string.
+ * Calculates the SHA-2 family of hash functions of a binary column and returns the value as a
+ * hex string.
*
- * @param e column to compute SHA-2 on.
- * @param numBits one of 224, 256, 384, or 512.
+ * @param e
+ * column to compute SHA-2 on.
+ * @param numBits
+ * one of 224, 256, 384, or 512.
*
* @group hash_funcs
* @since 1.5.0
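A short sketch of the digest helpers touched in these hunks (illustrative; assumes `spark` and `import spark.implicits._` are in scope):

  import org.apache.spark.sql.functions.{md5, sha1, sha2, col}

  val df = Seq("Spark").toDF("s")
  df.select(
    md5(col("s").cast("binary")),       // 32-character hex string
    sha1(col("s").cast("binary")),      // 40-character hex string
    sha2(col("s").cast("binary"), 256)  // 64-character hex string (SHA-256)
  ).show(false)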
@@ -3202,8 +3296,8 @@ object functions {
}
/**
- * Calculates the cyclic redundancy check value (CRC32) of a binary column and
- * returns the value as a bigint.
+ * Calculates the cyclic redundancy check value (CRC32) of a binary column and returns the value
+ * as a bigint.
*
* @group hash_funcs
* @since 1.5.0
@@ -3220,9 +3314,8 @@ object functions {
def hash(cols: Column*): Column = Column.fn("hash", cols: _*)
/**
- * Calculates the hash code of given columns using the 64-bit
- * variant of the xxHash algorithm, and returns the result as a long
- * column. The hash computation uses an initial seed of 42.
+ * Calculates the hash code of given columns using the 64-bit variant of the xxHash algorithm,
+ * and returns the result as a long column. The hash computation uses an initial seed of 42.
*
* @group hash_funcs
* @since 3.0.0
@@ -3255,8 +3348,8 @@ object functions {
def raise_error(c: Column): Column = Column.fn("raise_error", c)
/**
- * Returns the estimated number of unique values given the binary representation
- * of a Datasketches HllSketch.
+ * Returns the estimated number of unique values given the binary representation of a
+ * Datasketches HllSketch.
*
* @group misc_funcs
* @since 3.5.0
@@ -3264,8 +3357,8 @@ object functions {
def hll_sketch_estimate(c: Column): Column = Column.fn("hll_sketch_estimate", c)
/**
- * Returns the estimated number of unique values given the binary representation
- * of a Datasketches HllSketch.
+ * Returns the estimated number of unique values given the binary representation of a
+ * Datasketches HllSketch.
*
* @group misc_funcs
* @since 3.5.0
@@ -3275,9 +3368,8 @@ object functions {
}
/**
- * Merges two binary representations of Datasketches HllSketch objects, using a
- * Datasketches Union object. Throws an exception if sketches have different
- * lgConfigK values.
+ * Merges two binary representations of Datasketches HllSketch objects, using a Datasketches
+ * Union object. Throws an exception if sketches have different lgConfigK values.
*
* @group misc_funcs
* @since 3.5.0
@@ -3286,9 +3378,8 @@ object functions {
Column.fn("hll_union", c1, c2)
/**
- * Merges two binary representations of Datasketches HllSketch objects, using a
- * Datasketches Union object. Throws an exception if sketches have different
- * lgConfigK values.
+ * Merges two binary representations of Datasketches HllSketch objects, using a Datasketches
+ * Union object. Throws an exception if sketches have different lgConfigK values.
*
* @group misc_funcs
* @since 3.5.0
@@ -3298,9 +3389,9 @@ object functions {
}
/**
- * Merges two binary representations of Datasketches HllSketch objects, using a
- * Datasketches Union object. Throws an exception if sketches have different
- * lgConfigK values and allowDifferentLgConfigK is set to false.
+ * Merges two binary representations of Datasketches HllSketch objects, using a Datasketches
+ * Union object. Throws an exception if sketches have different lgConfigK values and
+ * allowDifferentLgConfigK is set to false.
*
* @group misc_funcs
* @since 3.5.0
@@ -3309,15 +3400,17 @@ object functions {
Column.fn("hll_union", c1, c2, lit(allowDifferentLgConfigK))
/**
- * Merges two binary representations of Datasketches HllSketch objects, using a
- * Datasketches Union object. Throws an exception if sketches have different
- * lgConfigK values and allowDifferentLgConfigK is set to false.
+ * Merges two binary representations of Datasketches HllSketch objects, using a Datasketches
+ * Union object. Throws an exception if sketches have different lgConfigK values and
+ * allowDifferentLgConfigK is set to false.
*
* @group misc_funcs
* @since 3.5.0
*/
- def hll_union(columnName1: String, columnName2: String, allowDifferentLgConfigK: Boolean):
- Column = {
+ def hll_union(
+ columnName1: String,
+ columnName2: String,
+ allowDifferentLgConfigK: Boolean): Column = {
hll_union(Column(columnName1), Column(columnName2), allowDifferentLgConfigK)
}
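To see the reformatted hll_union / hll_sketch_estimate docs in context, a hedged usage sketch (not part of the patch; assumes `spark` is in scope and pairs these with hll_sketch_agg to build the sketches):

  import org.apache.spark.sql.functions.{col, hll_sketch_agg, hll_sketch_estimate, hll_union}

  // Build one HllSketch per bucket, then estimate the distinct count from each sketch.
  val sketches = spark.range(1000)
    .withColumn("bucket", col("id") % 2)
    .groupBy("bucket")
    .agg(hll_sketch_agg(col("id")).as("sketch"))
  sketches.select(hll_sketch_estimate(col("sketch"))).show()

  // Two sketch columns (e.g. after a self-join) could be merged with
  // hll_union(col("left_sketch"), col("right_sketch")) before estimating.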
@@ -3684,8 +3777,8 @@ object functions {
Column.fn("bitmap_bucket_number", col)
/**
- * Returns a bitmap with the positions of the bits set from all the values from the input column.
- * The input column will most likely be bitmap_bit_position().
+ * Returns a bitmap with the positions of the bits set from all the values from the input
+ * column. The input column will most likely be bitmap_bit_position().
*
* @group agg_funcs
* @since 3.5.0
@@ -3702,8 +3795,8 @@ object functions {
def bitmap_count(col: Column): Column = Column.fn("bitmap_count", col)
/**
- * Returns a bitmap that is the bitwise OR of all of the bitmaps from the input column.
- * The input column should be bitmaps created from bitmap_construct_agg().
+ * Returns a bitmap that is the bitwise OR of all of the bitmaps from the input column. The
+ * input column should be bitmaps created from bitmap_construct_agg().
*
* @group agg_funcs
* @since 3.5.0
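A small end-to-end sketch of the bitmap aggregate pipeline these docs describe (illustrative; assumes `spark` is in scope):

  import org.apache.spark.sql.functions.{bitmap_bit_position, bitmap_construct_agg, bitmap_count, col}

  // bitmap_bit_position feeds bitmap_construct_agg; bitmap_count then yields the distinct count.
  spark.range(1, 100)
    .select(bitmap_bit_position(col("id")).as("pos"))
    .agg(bitmap_construct_agg(col("pos")).as("bitmap"))
    .select(bitmap_count(col("bitmap")))   // distinct count of the 99 ids
    .show()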
@@ -3724,8 +3817,8 @@ object functions {
def ascii(e: Column): Column = Column.fn("ascii", e)
/**
- * Computes the BASE64 encoding of a binary column and returns it as a string column.
- * This is the reverse of unbase64.
+ * Computes the BASE64 encoding of a binary column and returns it as a string column. This is
+ * the reverse of unbase64.
*
* @group string_funcs
* @since 1.5.0
@@ -3741,10 +3834,11 @@ object functions {
def bit_length(e: Column): Column = Column.fn("bit_length", e)
/**
- * Concatenates multiple input string columns together into a single string column,
- * using the given separator.
+ * Concatenates multiple input string columns together into a single string column, using the
+ * given separator.
*
- * @note Input strings which are null are skipped.
+ * @note
+ * Input strings which are null are skipped.
*
* @group string_funcs
* @since 1.5.0
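For the null-skipping note above, a minimal example (not part of the patch; assumes `spark` is in scope):

  import org.apache.spark.sql.functions.{concat_ws, col}

  val df = spark.sql("SELECT 'a' AS x, CAST(NULL AS STRING) AS y, 'c' AS z")
  df.select(concat_ws("-", col("x"), col("y"), col("z"))).show()   // "a-c": the null input is skipped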
@@ -3754,9 +3848,9 @@ object functions {
Column.fn("concat_ws", lit(sep) +: exprs: _*)
/**
- * Computes the first argument into a string from a binary using the provided character set
- * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16', 'UTF-32').
- * If either argument is null, the result will also be null.
+ * Computes the first argument into a string from a binary using the provided character set (one
+ * of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16', 'UTF-32'). If either
+ * argument is null, the result will also be null.
*
* @group string_funcs
* @since 1.5.0
@@ -3765,9 +3859,9 @@ object functions {
Column.fn("decode", value, lit(charset))
/**
- * Computes the first argument into a binary from a string using the provided character set
- * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16', 'UTF-32').
- * If either argument is null, the result will also be null.
+ * Computes the first argument into a binary from a string using the provided character set (one
+ * of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16', 'UTF-32'). If either
+ * argument is null, the result will also be null.
*
* @group string_funcs
* @since 1.5.0
@@ -3776,11 +3870,11 @@ object functions {
Column.fn("encode", value, lit(charset))
/**
- * Formats numeric column x to a format like '#,###,###.##', rounded to d decimal places
- * with HALF_EVEN round mode, and returns the result as a string column.
+ * Formats numeric column x to a format like '#,###,###.##', rounded to d decimal places with
+ * HALF_EVEN round mode, and returns the result as a string column.
*
- * If d is 0, the result has no decimal point or fractional part.
- * If d is less than 0, the result will be null.
+ * If d is 0, the result has no decimal point or fractional part. If d is less than 0, the
+ * result will be null.
*
* @group string_funcs
* @since 1.5.0
@@ -3798,8 +3892,8 @@ object functions {
Column.fn("format_string", lit(format) +: arguments: _*)
/**
- * Returns a new string column by converting the first letter of each word to uppercase.
- * Words are delimited by whitespace.
+ * Returns a new string column by converting the first letter of each word to uppercase. Words
+ * are delimited by whitespace.
*
* For example, "hello world" will become "Hello World".
*
@@ -3809,11 +3903,12 @@ object functions {
def initcap(e: Column): Column = Column.fn("initcap", e)
/**
- * Locate the position of the first occurrence of substr column in the given string.
- * Returns null if either of the arguments are null.
+ * Locate the position of the first occurrence of substr column in the given string. Returns
+ * null if either of the arguments are null.
*
- * @note The position is not zero based, but 1 based index. Returns 0 if substr
- * could not be found in str.
+ * @note
+ * The position is not zero based, but 1 based index. Returns 0 if substr could not be found
+ * in str.
*
* @group string_funcs
* @since 1.5.0
@@ -3821,8 +3916,8 @@ object functions {
def instr(str: Column, substring: String): Column = Column.fn("instr", str, lit(substring))
/**
- * Computes the character length of a given string or number of bytes of a binary string.
- * The length of character strings include the trailing spaces. The length of binary strings
+ * Computes the character length of a given string or number of bytes of a binary string. The
+ * length of character strings includes the trailing spaces. The length of binary strings
* includes binary zeros.
*
* @group string_funcs
@@ -3831,8 +3926,8 @@ object functions {
def length(e: Column): Column = Column.fn("length", e)
/**
- * Computes the character length of a given string or number of bytes of a binary string.
- * The length of character strings include the trailing spaces. The length of binary strings
+ * Computes the character length of a given string or number of bytes of a binary string. The
+ * length of character strings includes the trailing spaces. The length of binary strings
* includes binary zeros.
*
* @group string_funcs
@@ -3849,9 +3944,10 @@ object functions {
def lower(e: Column): Column = Column.fn("lower", e)
/**
- * Computes the Levenshtein distance of the two given string columns if it's less than or
- * equal to a given threshold.
- * @return result distance, or -1
+ * Computes the Levenshtein distance of the two given string columns if it's less than or equal
+ * to a given threshold.
+ * @return
+ * result distance, or -1
* @group string_funcs
* @since 3.5.0
*/
@@ -3868,8 +3964,9 @@ object functions {
/**
* Locate the position of the first occurrence of substr.
*
- * @note The position is not zero based, but 1 based index. Returns 0 if substr
- * could not be found in str.
+ * @note
+ * The position is not zero based, but 1 based index. Returns 0 if substr could not be found
+ * in str.
*
* @group string_funcs
* @since 1.5.0
@@ -3879,8 +3976,9 @@ object functions {
/**
* Locate the position of the first occurrence of substr in a string column, after position pos.
*
- * @note The position is not zero based, but 1 based index. returns 0 if substr
- * could not be found in str.
+ * @note
+ * The position is not zero based, but 1 based index. Returns 0 if substr could not be found
+ * in str.
*
* @group string_funcs
* @since 1.5.0
@@ -3889,8 +3987,8 @@ object functions {
Column.fn("locate", lit(substr), str, lit(pos))
/**
- * Left-pad the string column with pad to a length of len. If the string column is longer
- * than len, the return value is shortened to len characters.
+ * Left-pad the string column with pad to a length of len. If the string column is longer than
+ * len, the return value is shortened to len characters.
*
* @group string_funcs
* @since 1.5.0
@@ -3972,8 +4070,8 @@ object functions {
def regexp_like(str: Column, regexp: Column): Column = Column.fn("regexp_like", str, regexp)
/**
- * Returns a count of the number of times that the regular expression pattern `regexp`
- * is matched in the string `str`.
+ * Returns a count of the number of times that the regular expression pattern `regexp` is
+ * matched in the string `str`.
*
* @group string_funcs
* @since 3.5.0
@@ -3981,10 +4079,10 @@ object functions {
def regexp_count(str: Column, regexp: Column): Column = Column.fn("regexp_count", str, regexp)
/**
- * Extract a specific group matched by a Java regex, from the specified string column.
- * If the regex did not match, or the specified group did not match, an empty string is returned.
- * if the specified group index exceeds the group count of regex, an IllegalArgumentException
- * will be thrown.
+ * Extract a specific group matched by a Java regex, from the specified string column. If the
+ * regex did not match, or the specified group did not match, an empty string is returned. If
+ * the specified group index exceeds the group count of regex, an IllegalArgumentException will
+ * be thrown.
*
* @group string_funcs
* @since 1.5.0
@@ -3993,8 +4091,8 @@ object functions {
Column.fn("regexp_extract", e, lit(exp), lit(groupIdx))
/**
- * Extract all strings in the `str` that match the `regexp` expression and
- * corresponding to the first regex group index.
+ * Extract all strings in the `str` that match the `regexp` expression and corresponding to the
+ * first regex group index.
*
* @group string_funcs
* @since 3.5.0
@@ -4003,8 +4101,8 @@ object functions {
Column.fn("regexp_extract_all", str, regexp)
/**
- * Extract all strings in the `str` that match the `regexp` expression and
- * corresponding to the regex group index.
+ * Extract all strings in the `str` that match the `regexp` expression and corresponding to the
+ * regex group index.
*
* @group string_funcs
* @since 3.5.0
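A quick illustration of the group-index behaviour described in these regexp doc blocks (illustrative; assumes `spark` and `import spark.implicits._` are in scope):

  import org.apache.spark.sql.functions.{regexp_extract, regexp_extract_all, col, lit}

  val df = Seq("100-200, 300-400").toDF("s")
  df.select(
    regexp_extract(col("s"), "(\\d+)-(\\d+)", 1),               // "100": group 1 of the first match
    regexp_extract_all(col("s"), lit("(\\d+)-(\\d+)"), lit(2))  // ["200", "400"]: group 2 of every match
  ).show(false)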
@@ -4040,9 +4138,9 @@ object functions {
def regexp_substr(str: Column, regexp: Column): Column = Column.fn("regexp_substr", str, regexp)
/**
- * Searches a string for a regular expression and returns an integer that indicates
- * the beginning position of the matched substring. Positions are 1-based, not 0-based.
- * If no match is found, returns 0.
+ * Searches a string for a regular expression and returns an integer that indicates the
+ * beginning position of the matched substring. Positions are 1-based, not 0-based. If no match
+ * is found, returns 0.
*
* @group string_funcs
* @since 3.5.0
@@ -4050,9 +4148,9 @@ object functions {
def regexp_instr(str: Column, regexp: Column): Column = Column.fn("regexp_instr", str, regexp)
/**
- * Searches a string for a regular expression and returns an integer that indicates
- * the beginning position of the matched substring. Positions are 1-based, not 0-based.
- * If no match is found, returns 0.
+ * Searches a string for a regular expression and returns an integer that indicates the
+ * beginning position of the matched substring. Positions are 1-based, not 0-based. If no match
+ * is found, returns 0.
*
* @group string_funcs
* @since 3.5.0
@@ -4061,8 +4159,8 @@ object functions {
Column.fn("regexp_instr", str, regexp, idx)
/**
- * Decodes a BASE64 encoded string column and returns it as a binary column.
- * This is the reverse of base64.
+ * Decodes a BASE64 encoded string column and returns it as a binary column. This is the reverse
+ * of base64.
*
* @group string_funcs
* @since 1.5.0
@@ -4070,8 +4168,8 @@ object functions {
def unbase64(e: Column): Column = Column.fn("unbase64", e)
/**
- * Right-pad the string column with pad to a length of len. If the string column is longer
- * than len, the return value is shortened to len characters.
+ * Right-pad the string column with pad to a length of len. If the string column is longer than
+ * len, the return value is shortened to len characters.
*
* @group string_funcs
* @since 1.5.0
@@ -4199,11 +4297,11 @@ object functions {
Column.fn("split", str, pattern, limit)
/**
- * Substring starts at `pos` and is of length `len` when str is String type or
- * returns the slice of byte array that starts at `pos` in byte and is of length `len`
- * when str is Binary type
+ * Substring starts at `pos` and is of length `len` when str is String type or returns the slice
+ * of byte array that starts at `pos` in byte and is of length `len` when str is Binary type
*
- * @note The position is not zero based, but 1 based index.
+ * @note
+ * The position is not zero based, but 1 based index.
*
* @group string_funcs
* @since 1.5.0
@@ -4212,11 +4310,11 @@ object functions {
Column.fn("substring", str, lit(pos), lit(len))
/**
- * Substring starts at `pos` and is of length `len` when str is String type or
- * returns the slice of byte array that starts at `pos` in byte and is of length `len`
- * when str is Binary type
+ * Substring starts at `pos` and is of length `len` when str is String type or returns the slice
+ * of byte array that starts at `pos` in byte and is of length `len` when str is Binary type
*
- * @note The position is not zero based, but 1 based index.
+ * @note
+ * The position is not zero based, but 1 based index.
*
* @group string_funcs
* @since 4.0.0
@@ -4225,8 +4323,8 @@ object functions {
Column.fn("substring", str, pos, len)
/**
- * Returns the substring from string str before count occurrences of the delimiter delim.
- * If count is positive, everything the left of the final delimiter (counting from left) is
+ * Returns the substring from string str before count occurrences of the delimiter delim. If
+ * count is positive, everything to the left of the final delimiter (counting from left) is
* returned. If count is negative, every to the right of the final delimiter (counting from the
* right) is returned. substring_index performs a case-sensitive match when searching for delim.
*
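To ground the positive/negative count semantics above, a small example (not part of the patch; assumes `spark` and `import spark.implicits._` are in scope):

  import org.apache.spark.sql.functions.{substring_index, col}

  val df = Seq("www.apache.org").toDF("s")
  df.select(
    substring_index(col("s"), ".", 2),   // "www.apache": everything left of the 2nd '.'
    substring_index(col("s"), ".", -1)   // "org": everything right of the last '.'
  ).show()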
@@ -4236,8 +4334,8 @@ object functions {
Column.fn("substring_index", str, lit(delim), lit(count))
/**
- * Overlay the specified portion of `src` with `replace`,
- * starting from byte position `pos` of `src` and proceeding for `len` bytes.
+ * Overlay the specified portion of `src` with `replace`, starting from byte position `pos` of
+ * `src` and proceeding for `len` bytes.
*
* @group string_funcs
* @since 3.0.0
@@ -4246,8 +4344,8 @@ object functions {
Column.fn("overlay", src, replace, pos, len)
/**
- * Overlay the specified portion of `src` with `replace`,
- * starting from byte position `pos` of `src`.
+ * Overlay the specified portion of `src` with `replace`, starting from byte position `pos` of
+ * `src`.
*
* @group string_funcs
* @since 3.0.0
@@ -4264,18 +4362,26 @@ object functions {
Column.fn("sentences", string, language, country)
/**
- * Splits a string into arrays of sentences, where each sentence is an array of words.
- * The default locale is used.
+ * Splits a string into arrays of sentences, where each sentence is an array of words. The
+ * default `country` ('') is used.
+ * @group string_funcs
+ * @since 4.0.0
+ */
+ def sentences(string: Column, language: Column): Column =
+ Column.fn("sentences", string, language)
+
+ /**
+ * Splits a string into arrays of sentences, where each sentence is an array of words. The
+ * default locale is used.
* @group string_funcs
* @since 3.2.0
*/
def sentences(string: Column): Column = Column.fn("sentences", string)
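A brief sketch covering the existing overload and the new language-only overload added here (illustrative; assumes `spark` and `import spark.implicits._` are in scope):

  import org.apache.spark.sql.functions.{sentences, col, lit}

  val df = Seq("Hi there! Good morning.").toDF("s")
  df.select(
    sentences(col("s")),            // default locale
    sentences(col("s"), lit("en"))  // new in 4.0.0: language given, default country ''
  ).show(false)
  // Both yield [[Hi, there], [Good, morning]]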
/**
- * Translate any character in the src by a character in replaceString.
- * The characters in replaceString correspond to the characters in matchingString.
- * The translate will happen when any character in the string matches the character
- * in the `matchingString`.
+ * Translate any character in the src by a character in replaceString. The characters in
+ * replaceString correspond to the characters in matchingString. The translate will happen when
+ * any character in the string matches the character in the `matchingString`.
*
* @group string_funcs
* @since 1.5.0
@@ -4307,10 +4413,10 @@ object functions {
def upper(e: Column): Column = Column.fn("upper", e)
/**
- * Converts the input `e` to a binary value based on the supplied `format`.
- * The `format` can be a case-insensitive string literal of "hex", "utf-8", "utf8", or "base64".
- * By default, the binary format for conversion is "hex" if `format` is omitted.
- * The function returns NULL if at least one of the input parameters is NULL.
+ * Converts the input `e` to a binary value based on the supplied `format`. The `format` can be
+ * a case-insensitive string literal of "hex", "utf-8", "utf8", or "base64". By default, the
+ * binary format for conversion is "hex" if `format` is omitted. The function returns NULL if at
+ * least one of the input parameters is NULL.
*
* @group string_funcs
* @since 3.5.0
@@ -4318,8 +4424,8 @@ object functions {
def to_binary(e: Column, f: Column): Column = Column.fn("to_binary", e, f)
/**
- * Converts the input `e` to a binary value based on the default format "hex".
- * The function returns NULL if at least one of the input parameters is NULL.
+ * Converts the input `e` to a binary value based on the default format "hex". The function
+ * returns NULL if at least one of the input parameters is NULL.
*
* @group string_funcs
* @since 3.5.0
@@ -4328,32 +4434,27 @@ object functions {
// scalastyle:off line.size.limit
/**
- * Convert `e` to a string based on the `format`.
- * Throws an exception if the conversion fails. The format can consist of the following
- * characters, case insensitive:
- * '0' or '9': Specifies an expected digit between 0 and 9. A sequence of 0 or 9 in the format
- * string matches a sequence of digits in the input value, generating a result string of the
- * same length as the corresponding sequence in the format string. The result string is
- * left-padded with zeros if the 0/9 sequence comprises more digits than the matching part of
- * the decimal value, starts with 0, and is before the decimal point. Otherwise, it is
- * padded with spaces.
- * '.' or 'D': Specifies the position of the decimal point (optional, only allowed once).
- * ',' or 'G': Specifies the position of the grouping (thousands) separator (,). There must be
- * a 0 or 9 to the left and right of each grouping separator.
- * '$': Specifies the location of the $ currency sign. This character may only be specified
- * once.
- * 'S' or 'MI': Specifies the position of a '-' or '+' sign (optional, only allowed once at
- * the beginning or end of the format string). Note that 'S' prints '+' for positive values
- * but 'MI' prints a space.
- * 'PR': Only allowed at the end of the format string; specifies that the result string will be
- * wrapped by angle brackets if the input value is negative.
- *
- * If `e` is a datetime, `format` shall be a valid datetime pattern, see
- * Datetime Patterns.
- * If `e` is a binary, it is converted to a string in one of the formats:
- * 'base64': a base 64 string.
- * 'hex': a string in the hexadecimal format.
- * 'utf-8': the input binary is decoded to UTF-8 string.
+ * Convert `e` to a string based on the `format`. Throws an exception if the conversion fails.
+ * The format can consist of the following characters, case insensitive: '0' or '9': Specifies
+ * an expected digit between 0 and 9. A sequence of 0 or 9 in the format string matches a
+ * sequence of digits in the input value, generating a result string of the same length as the
+ * corresponding sequence in the format string. The result string is left-padded with zeros if
+ * the 0/9 sequence comprises more digits than the matching part of the decimal value, starts
+ * with 0, and is before the decimal point. Otherwise, it is padded with spaces. '.' or 'D':
+ * Specifies the position of the decimal point (optional, only allowed once). ',' or 'G':
+ * Specifies the position of the grouping (thousands) separator (,). There must be a 0 or 9 to
+ * the left and right of each grouping separator. '$': Specifies the location of the $ currency
+ * sign. This character may only be specified once. 'S' or 'MI': Specifies the position of a '-'
+ * or '+' sign (optional, only allowed once at the beginning or end of the format string). Note
+ * that 'S' prints '+' for positive values but 'MI' prints a space. 'PR': Only allowed at the
+ * end of the format string; specifies that the result string will be wrapped by angle brackets
+ * if the input value is negative.
+ *
+ * If `e` is a datetime, `format` shall be a valid datetime pattern, see Datetime
+ * Patterns. If `e` is a binary, it is converted to a string in one of the formats:
+ * 'base64': a base 64 string. 'hex': a string in the hexadecimal format. 'utf-8': the input
+ * binary is decoded to UTF-8 string.
*
* @group string_funcs
* @since 3.5.0
@@ -4363,32 +4464,27 @@ object functions {
// scalastyle:off line.size.limit
/**
- * Convert `e` to a string based on the `format`.
- * Throws an exception if the conversion fails. The format can consist of the following
- * characters, case insensitive:
- * '0' or '9': Specifies an expected digit between 0 and 9. A sequence of 0 or 9 in the format
- * string matches a sequence of digits in the input value, generating a result string of the
- * same length as the corresponding sequence in the format string. The result string is
- * left-padded with zeros if the 0/9 sequence comprises more digits than the matching part of
- * the decimal value, starts with 0, and is before the decimal point. Otherwise, it is
- * padded with spaces.
- * '.' or 'D': Specifies the position of the decimal point (optional, only allowed once).
- * ',' or 'G': Specifies the position of the grouping (thousands) separator (,). There must be
- * a 0 or 9 to the left and right of each grouping separator.
- * '$': Specifies the location of the $ currency sign. This character may only be specified
- * once.
- * 'S' or 'MI': Specifies the position of a '-' or '+' sign (optional, only allowed once at
- * the beginning or end of the format string). Note that 'S' prints '+' for positive values
- * but 'MI' prints a space.
- * 'PR': Only allowed at the end of the format string; specifies that the result string will be
- * wrapped by angle brackets if the input value is negative.
- *
- * If `e` is a datetime, `format` shall be a valid datetime pattern, see
- * Datetime Patterns.
- * If `e` is a binary, it is converted to a string in one of the formats:
- * 'base64': a base 64 string.
- * 'hex': a string in the hexadecimal format.
- * 'utf-8': the input binary is decoded to UTF-8 string.
+ * Convert `e` to a string based on the `format`. Throws an exception if the conversion fails.
+ * The format can consist of the following characters, case insensitive: '0' or '9': Specifies
+ * an expected digit between 0 and 9. A sequence of 0 or 9 in the format string matches a
+ * sequence of digits in the input value, generating a result string of the same length as the
+ * corresponding sequence in the format string. The result string is left-padded with zeros if
+ * the 0/9 sequence comprises more digits than the matching part of the decimal value, starts
+ * with 0, and is before the decimal point. Otherwise, it is padded with spaces. '.' or 'D':
+ * Specifies the position of the decimal point (optional, only allowed once). ',' or 'G':
+ * Specifies the position of the grouping (thousands) separator (,). There must be a 0 or 9 to
+ * the left and right of each grouping separator. '$': Specifies the location of the $ currency
+ * sign. This character may only be specified once. 'S' or 'MI': Specifies the position of a '-'
+ * or '+' sign (optional, only allowed once at the beginning or end of the format string). Note
+ * that 'S' prints '+' for positive values but 'MI' prints a space. 'PR': Only allowed at the
+ * end of the format string; specifies that the result string will be wrapped by angle brackets
+ * if the input value is negative.
+ *
+ * If `e` is a datetime, `format` shall be a valid datetime pattern, see Datetime
+ * Patterns. If `e` is a binary, it is converted to a string in one of the formats:
+ * 'base64': a base 64 string. 'hex': a string in the hexadecimal format. 'utf-8': the input
+ * binary is decoded to UTF-8 string.
*
* @group string_funcs
* @since 3.5.0
@@ -4397,24 +4493,21 @@ object functions {
def to_varchar(e: Column, format: Column): Column = Column.fn("to_varchar", e, format)
/**
- * Convert string 'e' to a number based on the string format 'format'.
- * Throws an exception if the conversion fails. The format can consist of the following
- * characters, case insensitive:
- * '0' or '9': Specifies an expected digit between 0 and 9. A sequence of 0 or 9 in the format
- * string matches a sequence of digits in the input string. If the 0/9 sequence starts with
- * 0 and is before the decimal point, it can only match a digit sequence of the same size.
- * Otherwise, if the sequence starts with 9 or is after the decimal point, it can match a
- * digit sequence that has the same or smaller size.
- * '.' or 'D': Specifies the position of the decimal point (optional, only allowed once).
- * ',' or 'G': Specifies the position of the grouping (thousands) separator (,). There must be
- * a 0 or 9 to the left and right of each grouping separator. 'expr' must match the
- * grouping separator relevant for the size of the number.
- * '$': Specifies the location of the $ currency sign. This character may only be specified
- * once.
- * 'S' or 'MI': Specifies the position of a '-' or '+' sign (optional, only allowed once at
- * the beginning or end of the format string). Note that 'S' allows '-' but 'MI' does not.
- * 'PR': Only allowed at the end of the format string; specifies that 'expr' indicates a
- * negative number with wrapping angled brackets.
+ * Convert string 'e' to a number based on the string format 'format'. Throws an exception if
+ * the conversion fails. The format can consist of the following characters, case insensitive:
+ * '0' or '9': Specifies an expected digit between 0 and 9. A sequence of 0 or 9 in the format
+ * string matches a sequence of digits in the input string. If the 0/9 sequence starts with 0
+ * and is before the decimal point, it can only match a digit sequence of the same size.
+ * Otherwise, if the sequence starts with 9 or is after the decimal point, it can match a digit
+ * sequence that has the same or smaller size. '.' or 'D': Specifies the position of the decimal
+ * point (optional, only allowed once). ',' or 'G': Specifies the position of the grouping
+ * (thousands) separator (,). There must be a 0 or 9 to the left and right of each grouping
+ * separator. 'expr' must match the grouping separator relevant for the size of the number. '$':
+ * Specifies the location of the $ currency sign. This character may only be specified once. 'S'
+ * or 'MI': Specifies the position of a '-' or '+' sign (optional, only allowed once at the
+ * beginning or end of the format string). Note that 'S' allows '-' but 'MI' does not. 'PR':
+ * Only allowed at the end of the format string; specifies that 'expr' indicates a negative
+ * number with wrapping angled brackets.
*
* @group string_funcs
* @since 3.5.0
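To make the reflowed format-string description above easier to scan, a couple of round-trip examples (illustrative only; assumes `spark` is in scope):

  import org.apache.spark.sql.functions.{to_char, to_number, lit}

  spark.range(1).select(
    to_char(lit(454), lit("999")),           // "454"
    to_char(lit(78.12), lit("$99.99")),      // "$78.12"
    to_number(lit("$78.12"), lit("$99.99"))  // 78.12
  ).show()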
@@ -4452,11 +4545,10 @@ object functions {
def replace(src: Column, search: Column): Column = Column.fn("replace", src, search)
/**
- * Splits `str` by delimiter and return requested part of the split (1-based).
- * If any input is null, returns null. if `partNum` is out of range of split parts,
- * returns empty string. If `partNum` is 0, throws an error. If `partNum` is negative,
- * the parts are counted backward from the end of the string.
- * If the `delimiter` is an empty string, the `str` is not split.
+ * Splits `str` by delimiter and returns the requested part of the split (1-based). If any input
+ * is null, returns null. If `partNum` is out of range of split parts, returns an empty string. If
+ * `partNum` is 0, throws an error. If `partNum` is negative, the parts are counted backward
+ * from the end of the string. If the `delimiter` is an empty string, the `str` is not split.
*
* @group string_funcs
* @since 3.5.0
@@ -4465,8 +4557,8 @@ object functions {
Column.fn("split_part", str, delimiter, partNum)
/**
- * Returns the substring of `str` that starts at `pos` and is of length `len`,
- * or the slice of byte array that starts at `pos` and is of length `len`.
+ * Returns the substring of `str` that starts at `pos` and is of length `len`, or the slice of
+ * byte array that starts at `pos` and is of length `len`.
*
* @group string_funcs
* @since 3.5.0
@@ -4475,8 +4567,8 @@ object functions {
Column.fn("substr", str, pos, len)
/**
- * Returns the substring of `str` that starts at `pos`,
- * or the slice of byte array that starts at `pos`.
+ * Returns the substring of `str` that starts at `pos`, or the slice of byte array that starts
+ * at `pos`.
*
* @group string_funcs
* @since 3.5.0
@@ -4511,8 +4603,8 @@ object functions {
Column.fn("printf", (format +: arguments): _*)
/**
- * Decodes a `str` in 'application/x-www-form-urlencoded' format
- * using a specific encoding scheme.
+ * Decodes a `str` in 'application/x-www-form-urlencoded' format using a specific encoding
+ * scheme.
*
* @group url_funcs
* @since 3.5.0
@@ -4520,8 +4612,8 @@ object functions {
def url_decode(str: Column): Column = Column.fn("url_decode", str)
/**
- * This is a special version of `url_decode` that performs the same operation, but returns
- * a NULL value instead of raising an error if the decoding cannot be performed.
+ * This is a special version of `url_decode` that performs the same operation, but returns a
+ * NULL value instead of raising an error if the decoding cannot be performed.
*
* @group url_funcs
* @since 4.0.0
@@ -4529,8 +4621,8 @@ object functions {
def try_url_decode(str: Column): Column = Column.fn("try_url_decode", str)
/**
- * Translates a string into 'application/x-www-form-urlencoded' format
- * using a specific encoding scheme.
+ * Translates a string into 'application/x-www-form-urlencoded' format using a specific encoding
+ * scheme.
*
* @group url_funcs
* @since 3.5.0
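A minimal sketch of the URL helpers documented in these hunks, including the non-throwing variant (illustrative; assumes `spark` is in scope):

  import org.apache.spark.sql.functions.{url_encode, url_decode, try_url_decode, lit}

  spark.range(1).select(
    url_encode(lit("https://spark.apache.org/docs?query=a b")),
    url_decode(lit("https%3A%2F%2Fspark.apache.org%2Fdocs%3Fquery%3Da+b")),
    try_url_decode(lit("%"))   // malformed escape: NULL instead of an error
  ).show(false)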
@@ -4538,8 +4630,8 @@ object functions {
def url_encode(str: Column): Column = Column.fn("url_encode", str)
/**
- * Returns the position of the first occurrence of `substr` in `str` after position `start`.
- * The given `start` and return value are 1-based.
+ * Returns the position of the first occurrence of `substr` in `str` after position `start`. The
+ * given `start` and return value are 1-based.
*
* @group string_funcs
* @since 3.5.0
@@ -4548,8 +4640,8 @@ object functions {
Column.fn("position", substr, str, start)
/**
- * Returns the position of the first occurrence of `substr` in `str` after position `1`.
- * The return value are 1-based.
+ * Returns the position of the first occurrence of `substr` in `str` after position `1`. The
+ * return value is 1-based.
*
* @group string_funcs
* @since 3.5.0
@@ -4558,9 +4650,9 @@ object functions {
Column.fn("position", substr, str)
/**
- * Returns a boolean. The value is True if str ends with suffix.
- * Returns NULL if either input expression is NULL. Otherwise, returns False.
- * Both str or suffix must be of STRING or BINARY type.
+ * Returns a boolean. The value is True if str ends with suffix. Returns NULL if either input
+ * expression is NULL. Otherwise, returns False. Both str and suffix must be of STRING or BINARY
+ * type.
*
* @group string_funcs
* @since 3.5.0
@@ -4569,9 +4661,9 @@ object functions {
Column.fn("endswith", str, suffix)
/**
- * Returns a boolean. The value is True if str starts with prefix.
- * Returns NULL if either input expression is NULL. Otherwise, returns False.
- * Both str or prefix must be of STRING or BINARY type.
+ * Returns a boolean. The value is True if str starts with prefix. Returns NULL if either input
+ * expression is NULL. Otherwise, returns False. Both str and prefix must be of STRING or BINARY
+ * type.
*
* @group string_funcs
* @since 3.5.0
@@ -4580,8 +4672,8 @@ object functions {
Column.fn("startswith", str, prefix)
/**
- * Returns the ASCII character having the binary equivalent to `n`.
- * If n is larger than 256 the result is equivalent to char(n % 256)
+ * Returns the ASCII character having the binary equivalent to `n`. If n is larger than 256, the
+ * result is equivalent to char(n % 256)
*
* @group string_funcs
* @since 3.5.0
@@ -4633,9 +4725,8 @@ object functions {
def try_to_number(e: Column, format: Column): Column = Column.fn("try_to_number", e, format)
/**
- * Returns the character length of string data or number of bytes of binary data.
- * The length of string data includes the trailing spaces.
- * The length of binary data includes binary zeros.
+ * Returns the character length of string data or number of bytes of binary data. The length of
+ * string data includes the trailing spaces. The length of binary data includes binary zeros.
*
* @group string_funcs
* @since 3.5.0
@@ -4643,9 +4734,8 @@ object functions {
def char_length(str: Column): Column = Column.fn("char_length", str)
/**
- * Returns the character length of string data or number of bytes of binary data.
- * The length of string data includes the trailing spaces.
- * The length of binary data includes binary zeros.
+ * Returns the character length of string data or number of bytes of binary data. The length of
+ * string data includes the trailing spaces. The length of binary data includes binary zeros.
*
* @group string_funcs
* @since 3.5.0
@@ -4653,8 +4743,8 @@ object functions {
def character_length(str: Column): Column = Column.fn("character_length", str)
/**
- * Returns the ASCII character having the binary equivalent to `n`.
- * If n is larger than 256 the result is equivalent to chr(n % 256)
+ * Returns the ASCII character having the binary equivalent to `n`. If n is larger than 256, the
+ * result is equivalent to chr(n % 256)
*
* @group string_funcs
* @since 3.5.0
@@ -4662,9 +4752,9 @@ object functions {
def chr(n: Column): Column = Column.fn("chr", n)
/**
- * Returns a boolean. The value is True if right is found inside left.
- * Returns NULL if either input expression is NULL. Otherwise, returns False.
- * Both left or right must be of STRING or BINARY type.
+ * Returns a boolean. The value is True if right is found inside left. Returns NULL if either
+ * input expression is NULL. Otherwise, returns False. Both left and right must be of STRING or
+ * BINARY type.
*
* @group string_funcs
* @since 3.5.0
@@ -4672,10 +4762,10 @@ object functions {
def contains(left: Column, right: Column): Column = Column.fn("contains", left, right)
/**
- * Returns the `n`-th input, e.g., returns `input2` when `n` is 2.
- * The function returns NULL if the index exceeds the length of the array
- * and `spark.sql.ansi.enabled` is set to false. If `spark.sql.ansi.enabled` is set to true,
- * it throws ArrayIndexOutOfBoundsException for invalid indices.
+ * Returns the `n`-th input, e.g., returns `input2` when `n` is 2. The function returns NULL if
+ * the index exceeds the length of the array and `spark.sql.ansi.enabled` is set to false. If
+ * `spark.sql.ansi.enabled` is set to true, it throws ArrayIndexOutOfBoundsException for invalid
+ * indices.
*
* @group string_funcs
* @since 3.5.0
@@ -4684,9 +4774,9 @@ object functions {
def elt(inputs: Column*): Column = Column.fn("elt", inputs: _*)
/**
- * Returns the index (1-based) of the given string (`str`) in the comma-delimited
- * list (`strArray`). Returns 0, if the string was not found or if the given string (`str`)
- * contains a comma.
+ * Returns the index (1-based) of the given string (`str`) in the comma-delimited list
+ * (`strArray`). Returns 0, if the string was not found or if the given string (`str`) contains
+ * a comma.
*
* @group string_funcs
* @since 3.5.0
@@ -4748,8 +4838,8 @@ object functions {
def ucase(str: Column): Column = Column.fn("ucase", str)
/**
- * Returns the leftmost `len`(`len` can be string type) characters from the string `str`,
- * if `len` is less or equal than 0 the result is an empty string.
+ * Returns the leftmost `len` (`len` can be string type) characters from the string `str`; if
+ * `len` is less than or equal to 0, the result is an empty string.
*
* @group string_funcs
* @since 3.5.0
@@ -4757,8 +4847,8 @@ object functions {
def left(str: Column, len: Column): Column = Column.fn("left", str, len)
/**
- * Returns the rightmost `len`(`len` can be string type) characters from the string `str`,
- * if `len` is less or equal than 0 the result is an empty string.
+ * Returns the rightmost `len` (`len` can be string type) characters from the string `str`; if
+ * `len` is less than or equal to 0, the result is an empty string.
*
* @group string_funcs
* @since 3.5.0
@@ -4772,23 +4862,29 @@ object functions {
/**
* Returns the date that is `numMonths` after `startDate`.
*
- * @param startDate A date, timestamp or string. If a string, the data must be in a format that
- * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
- * @param numMonths The number of months to add to `startDate`, can be negative to subtract months
- * @return A date, or null if `startDate` was a string that could not be cast to a date
+ * @param startDate
+ * A date, timestamp or string. If a string, the data must be in a format that can be cast to
+ * a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
+ * @param numMonths
+ * The number of months to add to `startDate`, can be negative to subtract months
+ * @return
+ * A date, or null if `startDate` was a string that could not be cast to a date
* @group datetime_funcs
* @since 1.5.0
*/
- def add_months(startDate: Column, numMonths: Int): Column = add_months(startDate, lit(numMonths))
+ def add_months(startDate: Column, numMonths: Int): Column =
+ add_months(startDate, lit(numMonths))
/**
* Returns the date that is `numMonths` after `startDate`.
*
- * @param startDate A date, timestamp or string. If a string, the data must be in a format that
- * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
- * @param numMonths A column of the number of months to add to `startDate`, can be negative to
- * subtract months
- * @return A date, or null if `startDate` was a string that could not be cast to a date
+ * @param startDate
+ * A date, timestamp or string. If a string, the data must be in a format that can be cast to
+ * a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
+ * @param numMonths
+ * A column of the number of months to add to `startDate`, can be negative to subtract months
+ * @return
+ * A date, or null if `startDate` was a string that could not be cast to a date
* @group datetime_funcs
* @since 3.0.0
*/
@@ -4796,8 +4892,8 @@ object functions {
Column.fn("add_months", startDate, numMonths)
/**
- * Returns the current date at the start of query evaluation as a date column.
- * All calls of current_date within the same query return the same value.
+ * Returns the current date at the start of query evaluation as a date column. All calls of
+ * current_date within the same query return the same value.
*
* @group datetime_funcs
* @since 3.5.0
@@ -4805,8 +4901,8 @@ object functions {
def curdate(): Column = Column.fn("curdate")
/**
- * Returns the current date at the start of query evaluation as a date column.
- * All calls of current_date within the same query return the same value.
+ * Returns the current date at the start of query evaluation as a date column. All calls of
+ * current_date within the same query return the same value.
*
* @group datetime_funcs
* @since 1.5.0
@@ -4822,8 +4918,8 @@ object functions {
def current_timezone(): Column = Column.fn("current_timezone")
/**
- * Returns the current timestamp at the start of query evaluation as a timestamp column.
- * All calls of current_timestamp within the same query return the same value.
+ * Returns the current timestamp at the start of query evaluation as a timestamp column. All
+ * calls of current_timestamp within the same query return the same value.
*
* @group datetime_funcs
* @since 1.5.0
@@ -4839,9 +4935,9 @@ object functions {
def now(): Column = Column.fn("now")
/**
- * Returns the current timestamp without time zone at the start of query evaluation
- * as a timestamp without time zone column.
- * All calls of localtimestamp within the same query return the same value.
+ * Returns the current timestamp without time zone at the start of query evaluation as a
+ * timestamp without time zone column. All calls of localtimestamp within the same query return
+ * the same value.
*
* @group datetime_funcs
* @since 3.3.0
@@ -4852,17 +4948,21 @@ object functions {
* Converts a date/timestamp/string to a value of string in the format specified by the date
* format given by the second argument.
*
- * See
- * Datetime Patterns
- * for valid date and time format patterns
- *
- * @param dateExpr A date, timestamp or string. If a string, the data must be in a format that
- * can be cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
- * @param format A pattern `dd.MM.yyyy` would return a string like `18.03.1993`
- * @return A string, or null if `dateExpr` was a string that could not be cast to a timestamp
- * @note Use specialized functions like [[year]] whenever possible as they benefit from a
- * specialized implementation.
- * @throws IllegalArgumentException if the `format` pattern is invalid
+ * See Datetime
+ * Patterns for valid date and time format patterns
+ *
+ * @param dateExpr
+ * A date, timestamp or string. If a string, the data must be in a format that can be cast to
+ * a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
+ * @param format
+ * A pattern `dd.MM.yyyy` would return a string like `18.03.1993`
+ * @return
+ * A string, or null if `dateExpr` was a string that could not be cast to a timestamp
+ * @note
+ * Use specialized functions like [[year]] whenever possible as they benefit from a
+ * specialized implementation.
+ * @throws IllegalArgumentException
+ * if the `format` pattern is invalid
* @group datetime_funcs
* @since 1.5.0
*/
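A usage sketch for `date_format`, with the usual `functions._` import and an assumed timestamp column `ts`:

df.select(date_format(col("ts"), "dd.MM.yyyy"))        // e.g. "18.03.1993"
df.select(date_format(col("ts"), "yyyy-MM-dd HH:mm"))  // any valid datetime pattern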
@@ -4872,10 +4972,13 @@ object functions {
/**
* Returns the date that is `days` days after `start`
*
- * @param start A date, timestamp or string. If a string, the data must be in a format that
- * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
- * @param days The number of days to add to `start`, can be negative to subtract days
- * @return A date, or null if `start` was a string that could not be cast to a date
+ * @param start
+ * A date, timestamp or string. If a string, the data must be in a format that can be cast to
+ * a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
+ * @param days
+ * The number of days to add to `start`, can be negative to subtract days
+ * @return
+ * A date, or null if `start` was a string that could not be cast to a date
* @group datetime_funcs
* @since 1.5.0
*/
@@ -4884,10 +4987,13 @@ object functions {
/**
* Returns the date that is `days` days after `start`
*
- * @param start A date, timestamp or string. If a string, the data must be in a format that
- * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
- * @param days A column of the number of days to add to `start`, can be negative to subtract days
- * @return A date, or null if `start` was a string that could not be cast to a date
+ * @param start
+ * A date, timestamp or string. If a string, the data must be in a format that can be cast to
+ * a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
+ * @param days
+ * A column of the number of days to add to `start`, can be negative to subtract days
+ * @return
+ * A date, or null if `start` was a string that could not be cast to a date
* @group datetime_funcs
* @since 3.0.0
*/
@@ -4896,10 +5002,13 @@ object functions {
/**
* Returns the date that is `days` days after `start`
*
- * @param start A date, timestamp or string. If a string, the data must be in a format that
- * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
- * @param days A column of the number of days to add to `start`, can be negative to subtract days
- * @return A date, or null if `start` was a string that could not be cast to a date
+ * @param start
+ * A date, timestamp or string. If a string, the data must be in a format that can be cast to
+ * a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
+ * @param days
+ * A column of the number of days to add to `start`, can be negative to subtract days
+ * @return
+ * A date, or null if `start` was a string that could not be cast to a date
* @group datetime_funcs
* @since 3.5.0
*/
@@ -4908,10 +5017,13 @@ object functions {
/**
* Returns the date that is `days` days before `start`
*
- * @param start A date, timestamp or string. If a string, the data must be in a format that
- * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
- * @param days The number of days to subtract from `start`, can be negative to add days
- * @return A date, or null if `start` was a string that could not be cast to a date
+ * @param start
+ * A date, timestamp or string. If a string, the data must be in a format that can be cast to
+ * a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
+ * @param days
+ * The number of days to subtract from `start`, can be negative to add days
+ * @return
+ * A date, or null if `start` was a string that could not be cast to a date
* @group datetime_funcs
* @since 1.5.0
*/
@@ -4920,11 +5032,13 @@ object functions {
/**
* Returns the date that is `days` days before `start`
*
- * @param start A date, timestamp or string. If a string, the data must be in a format that
- * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
- * @param days A column of the number of days to subtract from `start`, can be negative to add
- * days
- * @return A date, or null if `start` was a string that could not be cast to a date
+ * @param start
+ * A date, timestamp or string. If a string, the data must be in a format that can be cast to
+ * a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
+ * @param days
+ * A column of the number of days to subtract from `start`, can be negative to add days
+ * @return
+ * A date, or null if `start` was a string that could not be cast to a date
* @group datetime_funcs
* @since 3.0.0
*/
@@ -4940,12 +5054,15 @@ object functions {
* // returns 1
* }}}
*
- * @param end A date, timestamp or string. If a string, the data must be in a format that
- * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
- * @param start A date, timestamp or string. If a string, the data must be in a format that
- * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
- * @return An integer, or null if either `end` or `start` were strings that could not be cast to
- * a date. Negative if `end` is before `start`
+ * @param end
+ * A date, timestamp or string. If a string, the data must be in a format that can be cast to
+ * a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
+ * @param start
+ * A date, timestamp or string. If a string, the data must be in a format that can be cast to
+ * a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
+ * @return
+ * An integer, or null if either `end` or `start` were strings that could not be cast to a
+ * date. Negative if `end` is before `start`
* @group datetime_funcs
* @since 1.5.0
*/
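The day-arithmetic helpers side by side, assuming a DataFrame `df` with date columns `start` and `end` (illustrative):

df.select(date_add(col("start"), 7))            // 7 days later
df.select(date_sub(col("start"), 7))            // 7 days earlier
df.select(datediff(col("end"), col("start")))   // whole days; negative if end is before start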
@@ -4960,12 +5077,15 @@ object functions {
* // returns 1
* }}}
*
- * @param end A date, timestamp or string. If a string, the data must be in a format that
- * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
- * @param start A date, timestamp or string. If a string, the data must be in a format that
- * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
- * @return An integer, or null if either `end` or `start` were strings that could not be cast to
- * a date. Negative if `end` is before `start`
+ * @param end
+ * A date, timestamp or string. If a string, the data must be in a format that can be cast to
+ * a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
+ * @param start
+ * A date, timestamp or string. If a string, the data must be in a format that can be cast to
+ * a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
+ * @return
+ * An integer, or null if either `end` or `start` were strings that could not be cast to a
+ * date. Negative if `end` is before `start`
* @group datetime_funcs
* @since 3.5.0
*/
@@ -4981,7 +5101,8 @@ object functions {
/**
* Extracts the year as an integer from a given date/timestamp/string.
- * @return An integer, or null if the input was a string that could not be cast to a date
+ * @return
+ * An integer, or null if the input was a string that could not be cast to a date
* @group datetime_funcs
* @since 1.5.0
*/
@@ -4989,7 +5110,8 @@ object functions {
/**
* Extracts the quarter as an integer from a given date/timestamp/string.
- * @return An integer, or null if the input was a string that could not be cast to a date
+ * @return
+ * An integer, or null if the input was a string that could not be cast to a date
* @group datetime_funcs
* @since 1.5.0
*/
@@ -4997,16 +5119,18 @@ object functions {
/**
* Extracts the month as an integer from a given date/timestamp/string.
- * @return An integer, or null if the input was a string that could not be cast to a date
+ * @return
+ * An integer, or null if the input was a string that could not be cast to a date
* @group datetime_funcs
* @since 1.5.0
*/
def month(e: Column): Column = Column.fn("month", e)
/**
- * Extracts the day of the week as an integer from a given date/timestamp/string.
- * Ranges from 1 for a Sunday through to 7 for a Saturday
- * @return An integer, or null if the input was a string that could not be cast to a date
+ * Extracts the day of the week as an integer from a given date/timestamp/string. Ranges from 1
+ * for a Sunday through to 7 for a Saturday
+ * @return
+ * An integer, or null if the input was a string that could not be cast to a date
* @group datetime_funcs
* @since 2.3.0
*/
@@ -5014,7 +5138,8 @@ object functions {
/**
* Extracts the day of the month as an integer from a given date/timestamp/string.
- * @return An integer, or null if the input was a string that could not be cast to a date
+ * @return
+ * An integer, or null if the input was a string that could not be cast to a date
* @group datetime_funcs
* @since 1.5.0
*/
@@ -5022,7 +5147,8 @@ object functions {
/**
* Extracts the day of the month as an integer from a given date/timestamp/string.
- * @return An integer, or null if the input was a string that could not be cast to a date
+ * @return
+ * An integer, or null if the input was a string that could not be cast to a date
* @group datetime_funcs
* @since 3.5.0
*/
@@ -5030,7 +5156,8 @@ object functions {
/**
* Extracts the day of the year as an integer from a given date/timestamp/string.
- * @return An integer, or null if the input was a string that could not be cast to a date
+ * @return
+ * An integer, or null if the input was a string that could not be cast to a date
* @group datetime_funcs
* @since 1.5.0
*/
@@ -5038,7 +5165,8 @@ object functions {
/**
* Extracts the hours as an integer from a given date/timestamp/string.
- * @return An integer, or null if the input was a string that could not be cast to a date
+ * @return
+ * An integer, or null if the input was a string that could not be cast to a date
* @group datetime_funcs
* @since 1.5.0
*/
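A sketch combining the field extractors documented above, assuming a date column `d` and a timestamp column `ts`:

df.select(year(col("d")), quarter(col("d")), month(col("d")),
  dayofweek(col("d")), dayofmonth(col("d")), dayofyear(col("d")), hour(col("ts")))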
@@ -5047,9 +5175,12 @@ object functions {
/**
* Extracts a part of the date/timestamp or interval source.
*
- * @param field selects which part of the source should be extracted.
- * @param source a date/timestamp or interval column from where `field` should be extracted.
- * @return a part of the date/timestamp or interval source
+ * @param field
+ * selects which part of the source should be extracted.
+ * @param source
+ * a date/timestamp or interval column from where `field` should be extracted.
+ * @return
+ * a part of the date/timestamp or interval source
* @group datetime_funcs
* @since 3.5.0
*/
@@ -5060,10 +5191,13 @@ object functions {
/**
* Extracts a part of the date/timestamp or interval source.
*
- * @param field selects which part of the source should be extracted, and supported string values
- * are as same as the fields of the equivalent function `extract`.
- * @param source a date/timestamp or interval column from where `field` should be extracted.
- * @return a part of the date/timestamp or interval source
+ * @param field
+ * selects which part of the source should be extracted, and supported string values are the
+ * same as the fields of the equivalent function `extract`.
+ * @param source
+ * a date/timestamp or interval column from where `field` should be extracted.
+ * @return
+ * a part of the date/timestamp or interval source
* @group datetime_funcs
* @since 3.5.0
*/
@@ -5074,10 +5208,13 @@ object functions {
/**
* Extracts a part of the date/timestamp or interval source.
*
- * @param field selects which part of the source should be extracted, and supported string values
- * are as same as the fields of the equivalent function `EXTRACT`.
- * @param source a date/timestamp or interval column from where `field` should be extracted.
- * @return a part of the date/timestamp or interval source
+ * @param field
+ * selects which part of the source should be extracted, and supported string values are the
+ * same as the fields of the equivalent function `EXTRACT`.
+ * @param source
+ * a date/timestamp or interval column from where `field` should be extracted.
+ * @return
+ * a part of the date/timestamp or interval source
* @group datetime_funcs
* @since 3.5.0
*/
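A sketch of the three extraction entry points; the field names "YEAR" and "MONTH" are illustrative values accepted by SQL `EXTRACT`:

df.select(extract(lit("YEAR"), col("ts")))     // field passed as a Column
df.select(date_part(lit("MONTH"), col("ts")))  // same semantics, SQL-style name
df.select(datepart(lit("MONTH"), col("ts")))   // alias of date_part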
@@ -5086,13 +5223,14 @@ object functions {
}
/**
- * Returns the last day of the month which the given date belongs to.
- * For example, input "2015-07-27" returns "2015-07-31" since July 31 is the last day of the
- * month in July 2015.
+ * Returns the last day of the month which the given date belongs to. For example, input
+ * "2015-07-27" returns "2015-07-31" since July 31 is the last day of the month in July 2015.
*
- * @param e A date, timestamp or string. If a string, the data must be in a format that can be
- * cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
- * @return A date, or null if the input was a string that could not be cast to a date
+ * @param e
+ * A date, timestamp or string. If a string, the data must be in a format that can be cast to
+ * a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
+ * @return
+ * A date, or null if the input was a string that could not be cast to a date
* @group datetime_funcs
* @since 1.5.0
*/
@@ -5100,7 +5238,8 @@ object functions {
/**
* Extracts the minutes as an integer from a given date/timestamp/string.
- * @return An integer, or null if the input was a string that could not be cast to a date
+ * @return
+ * An integer, or null if the input was a string that could not be cast to a date
* @group datetime_funcs
* @since 1.5.0
*/
@@ -5115,7 +5254,8 @@ object functions {
def weekday(e: Column): Column = Column.fn("weekday", e)
/**
- * @return A date created from year, month and day fields.
+ * @return
+ * A date created from year, month and day fields.
* @group datetime_funcs
* @since 3.3.0
*/
@@ -5126,7 +5266,8 @@ object functions {
* Returns number of months between dates `start` and `end`.
*
* A whole number is returned if both inputs have the same day of month or both are the last day
- * of their respective months. Otherwise, the difference is calculated assuming 31 days per month.
+ * of their respective months. Otherwise, the difference is calculated assuming 31 days per
+ * month.
*
* For example:
* {{{
@@ -5135,12 +5276,15 @@ object functions {
* months_between("2017-06-01", "2017-06-16 12:00:00") // returns -0.5
* }}}
*
- * @param end A date, timestamp or string. If a string, the data must be in a format that can
- * be cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
- * @param start A date, timestamp or string. If a string, the data must be in a format that can
- * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
- * @return A double, or null if either `end` or `start` were strings that could not be cast to a
- * timestamp. Negative if `end` is before `start`
+ * @param end
+ * A date, timestamp or string. If a string, the data must be in a format that can be cast to
+ * a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
+ * @param start
+ * A date, timestamp or string. If a string, the data must be in a format that can be cast to
+ * a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
+ * @return
+ * A double, or null if either `end` or `start` were strings that could not be cast to a
+ * timestamp. Negative if `end` is before `start`
* @group datetime_funcs
* @since 1.5.0
*/
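The documented behaviour expressed through the Scala API (string dates wrapped with `lit`; `end`/`start` are assumed columns):

df.select(months_between(col("end"), col("start")))
df.select(months_between(lit("2017-06-01"), lit("2017-06-16 12:00:00")))  // -0.5, per the example above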
@@ -5163,11 +5307,14 @@ object functions {
* For example, `next_day('2015-07-27', "Sunday")` returns 2015-08-02 because that is the first
* Sunday after 2015-07-27.
*
- * @param date A date, timestamp or string. If a string, the data must be in a format that
- * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
- * @param dayOfWeek Case insensitive, and accepts: "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"
- * @return A date, or null if `date` was a string that could not be cast to a date or if
- * `dayOfWeek` was an invalid value
+ * @param date
+ * A date, timestamp or string. If a string, the data must be in a format that can be cast to
+ * a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
+ * @param dayOfWeek
+ * Case insensitive, and accepts: "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"
+ * @return
+ * A date, or null if `date` was a string that could not be cast to a date or if `dayOfWeek`
+ * was an invalid value
* @group datetime_funcs
* @since 1.5.0
*/
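A usage sketch for `next_day`, assuming a date column `d`:

df.select(next_day(col("d"), "Sunday"))        // first Sunday after d
df.select(next_day(lit("2015-07-27"), "Sun"))  // 2015-08-02, per the example above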
@@ -5180,12 +5327,15 @@ object functions {
* For example, `next_day('2015-07-27', "Sunday")` returns 2015-08-02 because that is the first
* Sunday after 2015-07-27.
*
- * @param date A date, timestamp or string. If a string, the data must be in a format that
- * can be cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
- * @param dayOfWeek A column of the day of week. Case insensitive, and accepts: "Mon", "Tue",
- * "Wed", "Thu", "Fri", "Sat", "Sun"
- * @return A date, or null if `date` was a string that could not be cast to a date or if
- * `dayOfWeek` was an invalid value
+ * @param date
+ * A date, timestamp or string. If a string, the data must be in a format that can be cast to
+ * a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
+ * @param dayOfWeek
+ * A column of the day of week. Case insensitive, and accepts: "Mon", "Tue", "Wed", "Thu",
+ * "Fri", "Sat", "Sun"
+ * @return
+ * A date, or null if `date` was a string that could not be cast to a date or if `dayOfWeek`
+ * was an invalid value
* @group datetime_funcs
* @since 3.2.0
*/
@@ -5194,7 +5344,8 @@ object functions {
/**
* Extracts the seconds as an integer from a given date/timestamp/string.
- * @return An integer, or null if the input was a string that could not be cast to a timestamp
+ * @return
+ * An integer, or null if the input was a string that could not be cast to a timestamp
* @group datetime_funcs
* @since 1.5.0
*/
@@ -5206,7 +5357,8 @@ object functions {
* A week is considered to start on a Monday and week 1 is the first week with more than 3 days,
* as defined by ISO 8601
*
- * @return An integer, or null if the input was a string that could not be cast to a date
+ * @return
+ * An integer, or null if the input was a string that could not be cast to a date
* @group datetime_funcs
* @since 1.5.0
*/
@@ -5214,12 +5366,14 @@ object functions {
/**
* Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string
- * representing the timestamp of that moment in the current system time zone in the
- * yyyy-MM-dd HH:mm:ss format.
- *
- * @param ut A number of a type that is castable to a long, such as string or integer. Can be
- * negative for timestamps before the unix epoch
- * @return A string, or null if the input was a string that could not be cast to a long
+ * representing the timestamp of that moment in the current system time zone in the yyyy-MM-dd
+ * HH:mm:ss format.
+ *
+ * @param ut
+ * A number of a type that is castable to a long, such as string or integer. Can be negative
+ * for timestamps before the unix epoch
+ * @return
+ * A string, or null if the input was a string that could not be cast to a long
* @group datetime_funcs
* @since 1.5.0
*/
@@ -5230,15 +5384,17 @@ object functions {
* representing the timestamp of that moment in the current system time zone in the given
* format.
*
- * See
- * Datetime Patterns
- * for valid date and time format patterns
- *
- * @param ut A number of a type that is castable to a long, such as string or integer. Can be
- * negative for timestamps before the unix epoch
- * @param f A date time pattern that the input will be formatted to
- * @return A string, or null if `ut` was a string that could not be cast to a long or `f` was
- * an invalid date time pattern
+ * See Datetime
+ * Patterns for valid date and time format patterns
+ *
+ * @param ut
+ * A number of a type that is castable to a long, such as string or integer. Can be negative
+ * for timestamps before the unix epoch
+ * @param f
+ * A date time pattern that the input will be formatted to
+ * @return
+ * A string, or null if `ut` was a string that could not be cast to a long or `f` was an
+ * invalid date time pattern
* @group datetime_funcs
* @since 1.5.0
*/
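A usage sketch for `from_unixtime`, assuming a long column `ut` of epoch seconds:

df.select(from_unixtime(col("ut")))                // default yyyy-MM-dd HH:mm:ss
df.select(from_unixtime(col("ut"), "yyyy/MM/dd"))  // custom pattern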
@@ -5248,8 +5404,9 @@ object functions {
/**
* Returns the current Unix timestamp (in seconds) as a long.
*
- * @note All calls of `unix_timestamp` within the same query return the same value
- * (i.e. the current timestamp is calculated at the start of query evaluation).
+ * @note
+ * All calls of `unix_timestamp` within the same query return the same value (i.e. the current
+ * timestamp is calculated at the start of query evaluation).
*
* @group datetime_funcs
* @since 1.5.0
@@ -5257,12 +5414,14 @@ object functions {
def unix_timestamp(): Column = unix_timestamp(current_timestamp())
/**
- * Converts time string in format yyyy-MM-dd HH:mm:ss to Unix timestamp (in seconds),
- * using the default timezone and the default locale.
+ * Converts time string in format yyyy-MM-dd HH:mm:ss to Unix timestamp (in seconds), using the
+ * default timezone and the default locale.
*
- * @param s A date, timestamp or string. If a string, the data must be in the
- * `yyyy-MM-dd HH:mm:ss` format
- * @return A long, or null if the input was a string not of the correct format
+ * @param s
+ * A date, timestamp or string. If a string, the data must be in the `yyyy-MM-dd HH:mm:ss`
+ * format
+ * @return
+ * A long, or null if the input was a string not of the correct format
* @group datetime_funcs
* @since 1.5.0
*/
@@ -5271,15 +5430,17 @@ object functions {
/**
* Converts time string with given pattern to Unix timestamp (in seconds).
*
- * See
- * Datetime Patterns
- * for valid date and time format patterns
- *
- * @param s A date, timestamp or string. If a string, the data must be in a format that can be
- * cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
- * @param p A date time pattern detailing the format of `s` when `s` is a string
- * @return A long, or null if `s` was a string that could not be cast to a date or `p` was
- * an invalid format
+ * See Datetime
+ * Patterns for valid date and time format patterns
+ *
+ * @param s
+ * A date, timestamp or string. If a string, the data must be in a format that can be cast to
+ * a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
+ * @param p
+ * A date time pattern detailing the format of `s` when `s` is a string
+ * @return
+ * A long, or null if `s` was a string that could not be cast to a date or `p` was an invalid
+ * format
* @group datetime_funcs
* @since 1.5.0
*/
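The `unix_timestamp` overloads in one sketch, assuming a string column `s`:

df.select(unix_timestamp())                        // current epoch seconds, constant within a query
df.select(unix_timestamp(col("s")))                // expects yyyy-MM-dd HH:mm:ss
df.select(unix_timestamp(col("s"), "yyyy-MM-dd"))  // custom pattern; null on parse failure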
@@ -5289,9 +5450,11 @@ object functions {
/**
* Converts to a timestamp by casting rules to `TimestampType`.
*
- * @param s A date, timestamp or string. If a string, the data must be in a format that can be
- * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
- * @return A timestamp, or null if the input was a string that could not be cast to a timestamp
+ * @param s
+ * A date, timestamp or string. If a string, the data must be in a format that can be cast to
+ * a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
+ * @return
+ * A timestamp, or null if the input was a string that could not be cast to a timestamp
* @group datetime_funcs
* @since 2.2.0
*/
@@ -5300,15 +5463,17 @@ object functions {
/**
* Converts time string with the given pattern to timestamp.
*
- * See
- * Datetime Patterns
- * for valid date and time format patterns
- *
- * @param s A date, timestamp or string. If a string, the data must be in a format that can be
- * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
- * @param fmt A date time pattern detailing the format of `s` when `s` is a string
- * @return A timestamp, or null if `s` was a string that could not be cast to a timestamp or
- * `fmt` was an invalid format
+ * See Datetime
+ * Patterns for valid date and time format patterns
+ *
+ * @param s
+ * A date, timestamp or string. If a string, the data must be in a format that can be cast to
+ * a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
+ * @param fmt
+ * A date time pattern detailing the format of `s` when `s` is a string
+ * @return
+ * A timestamp, or null if `s` was a string that could not be cast to a timestamp or `fmt` was
+ * an invalid format
* @group datetime_funcs
* @since 2.2.0
*/
@@ -5326,9 +5491,9 @@ object functions {
Column.fn("try_to_timestamp", s, format)
/**
- * Parses the `s` to a timestamp. The function always returns null on an invalid
- * input with`/`without ANSI SQL mode enabled. It follows casting rules to a timestamp. The
- * result data type is consistent with the value of configuration `spark.sql.timestampType`.
+ * Parses `s` to a timestamp. The function always returns null on an invalid input, with or
+ * without ANSI SQL mode enabled. It follows casting rules to a timestamp. The result data
+ * type is consistent with the value of configuration `spark.sql.timestampType`.
*
* @group datetime_funcs
* @since 3.5.0
@@ -5346,15 +5511,17 @@ object functions {
/**
* Converts the column into a `DateType` with a specified format
*
- * See
- * Datetime Patterns
- * for valid date and time format patterns
- *
- * @param e A date, timestamp or string. If a string, the data must be in a format that can be
- * cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
- * @param fmt A date time pattern detailing the format of `e` when `e`is a string
- * @return A date, or null if `e` was a string that could not be cast to a date or `fmt` was an
- * invalid format
+ * See Datetime
+ * Patterns for valid date and time format patterns
+ *
+ * @param e
+ * A date, timestamp or string. If a string, the data must be in a format that can be cast to
+ * a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
+ * @param fmt
+ * A date time pattern detailing the format of `e` when `e` is a string
+ * @return
+ * A date, or null if `e` was a string that could not be cast to a date or `fmt` was an
+ * invalid format
* @group datetime_funcs
* @since 2.2.0
*/
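A sketch of the parsing functions, assuming a string column `s`:

df.select(to_timestamp(col("s")))                         // cast-style parsing
df.select(to_timestamp(col("s"), "yyyy-MM-dd HH:mm:ss"))  // explicit pattern
df.select(to_date(col("s"), "dd.MM.yyyy"))                // null if s does not match or the pattern is invalid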
@@ -5377,8 +5544,8 @@ object functions {
def unix_micros(e: Column): Column = Column.fn("unix_micros", e)
/**
- * Returns the number of milliseconds since 1970-01-01 00:00:00 UTC.
- * Truncates higher levels of precision.
+ * Returns the number of milliseconds since 1970-01-01 00:00:00 UTC. Truncates higher levels of
+ * precision.
*
* @group datetime_funcs
* @since 3.5.0
@@ -5386,8 +5553,8 @@ object functions {
def unix_millis(e: Column): Column = Column.fn("unix_millis", e)
/**
- * Returns the number of seconds since 1970-01-01 00:00:00 UTC.
- * Truncates higher levels of precision.
+ * Returns the number of seconds since 1970-01-01 00:00:00 UTC. Truncates higher levels of
+ * precision.
*
* @group datetime_funcs
* @since 3.5.0
@@ -5399,14 +5566,16 @@ object functions {
*
* For example, `trunc("2018-11-19 12:01:19", "year")` returns 2018-01-01
*
- * @param date A date, timestamp or string. If a string, the data must be in a format that can be
- * cast to a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
- * @param format: 'year', 'yyyy', 'yy' to truncate by year,
- * or 'month', 'mon', 'mm' to truncate by month
- * Other options are: 'week', 'quarter'
+ * @param date
+ * A date, timestamp or string. If a string, the data must be in a format that can be cast to
+ * a date, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
+ * @param format:
+ * 'year', 'yyyy', 'yy' to truncate by year, or 'month', 'mon', 'mm' to truncate by month.
+ * Other options are: 'week', 'quarter'
*
- * @return A date, or null if `date` was a string that could not be cast to a date or `format`
- * was an invalid value
+ * @return
+ * A date, or null if `date` was a string that could not be cast to a date or `format` was an
+ * invalid value
* @group datetime_funcs
* @since 1.5.0
*/
@@ -5417,15 +5586,16 @@ object functions {
*
* For example, `date_trunc("year", "2018-11-19 12:01:19")` returns 2018-01-01 00:00:00
*
- * @param format: 'year', 'yyyy', 'yy' to truncate by year,
- * 'month', 'mon', 'mm' to truncate by month,
- * 'day', 'dd' to truncate by day,
- * Other options are:
- * 'microsecond', 'millisecond', 'second', 'minute', 'hour', 'week', 'quarter'
- * @param timestamp A date, timestamp or string. If a string, the data must be in a format that
- * can be cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
- * @return A timestamp, or null if `timestamp` was a string that could not be cast to a timestamp
- * or `format` was an invalid value
+ * @param format:
+ * 'year', 'yyyy', 'yy' to truncate by year, 'month', 'mon', 'mm' to truncate by month, 'day',
+ * 'dd' to truncate by day. Other options are: 'microsecond', 'millisecond', 'second',
+ * 'minute', 'hour', 'week', 'quarter'
+ * @param timestamp
+ * A date, timestamp or string. If a string, the data must be in a format that can be cast to
+ * a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
+ * @return
+ * A timestamp, or null if `timestamp` was a string that could not be cast to a timestamp or
+ * `format` was an invalid value
* @group datetime_funcs
* @since 2.3.0
*/
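A sketch contrasting `trunc` and `date_trunc`; note the reversed argument order between the two:

df.select(trunc(col("d"), "month"))       // date truncated to the first day of its month
df.select(date_trunc("hour", col("ts")))  // timestamp rounded down to the hour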
@@ -5434,19 +5604,21 @@ object functions {
/**
* Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders
- * that time as a timestamp in the given time zone. For example, 'GMT+1' would yield
- * '2017-07-14 03:40:00.0'.
- *
- * @param ts A date, timestamp or string. If a string, the data must be in a format that can be
- * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
- * @param tz A string detailing the time zone ID that the input should be adjusted to. It should
- * be in the format of either region-based zone IDs or zone offsets. Region IDs must
- * have the form 'area/city', such as 'America/Los_Angeles'. Zone offsets must be in
- * the format '(+|-)HH:mm', for example '-08:00' or '+01:00'. Also 'UTC' and 'Z' are
- * supported as aliases of '+00:00'. Other short names are not recommended to use
- * because they can be ambiguous.
- * @return A timestamp, or null if `ts` was a string that could not be cast to a timestamp or
- * `tz` was an invalid value
+ * that time as a timestamp in the given time zone. For example, 'GMT+1' would yield '2017-07-14
+ * 03:40:00.0'.
+ *
+ * @param ts
+ * A date, timestamp or string. If a string, the data must be in a format that can be cast to
+ * a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
+ * @param tz
+ * A string detailing the time zone ID that the input should be adjusted to. It should be in
+ * the format of either region-based zone IDs or zone offsets. Region IDs must have the form
+ * 'area/city', such as 'America/Los_Angeles'. Zone offsets must be in the format
+ * '(+|-)HH:mm', for example '-08:00' or '+01:00'. Also 'UTC' and 'Z' are supported as aliases
+ * of '+00:00'. Other short names are not recommended because they can be ambiguous.
+ * @return
+ * A timestamp, or null if `ts` was a string that could not be cast to a timestamp or `tz` was
+ * an invalid value
* @group datetime_funcs
* @since 1.5.0
*/
@@ -5454,8 +5626,8 @@ object functions {
/**
* Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a time in UTC, and renders
- * that time as a timestamp in the given time zone. For example, 'GMT+1' would yield
- * '2017-07-14 03:40:00.0'.
+ * that time as a timestamp in the given time zone. For example, 'GMT+1' would yield '2017-07-14
+ * 03:40:00.0'.
* @group datetime_funcs
* @since 2.4.0
*/
@@ -5467,16 +5639,18 @@ object functions {
* zone, and renders that time as a timestamp in UTC. For example, 'GMT+1' would yield
* '2017-07-14 01:40:00.0'.
*
- * @param ts A date, timestamp or string. If a string, the data must be in a format that can be
- * cast to a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
- * @param tz A string detailing the time zone ID that the input should be adjusted to. It should
- * be in the format of either region-based zone IDs or zone offsets. Region IDs must
- * have the form 'area/city', such as 'America/Los_Angeles'. Zone offsets must be in
- * the format '(+|-)HH:mm', for example '-08:00' or '+01:00'. Also 'UTC' and 'Z' are
- * supported as aliases of '+00:00'. Other short names are not recommended to use
- * because they can be ambiguous.
- * @return A timestamp, or null if `ts` was a string that could not be cast to a timestamp or
- * `tz` was an invalid value
+ * @param ts
+ * A date, timestamp or string. If a string, the data must be in a format that can be cast to
+ * a timestamp, such as `yyyy-MM-dd` or `yyyy-MM-dd HH:mm:ss.SSSS`
+ * @param tz
+ * A string detailing the time zone ID that the input should be adjusted to. It should be in
+ * the format of either region-based zone IDs or zone offsets. Region IDs must have the form
+ * 'area/city', such as 'America/Los_Angeles'. Zone offsets must be in the format
+ * '(+|-)HH:mm', for example '-08:00' or '+01:00'. Also 'UTC' and 'Z' are supported as aliases
+ * of '+00:00'. Other short names are not recommended because they can be ambiguous.
+ * @return
+ * A timestamp, or null if `ts` was a string that could not be cast to a timestamp or `tz` was
+ * an invalid value
* @group datetime_funcs
* @since 1.5.0
*/
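A sketch of the two time-zone conversion functions, assuming a timestamp column `ts`:

df.select(from_utc_timestamp(col("ts"), "America/Los_Angeles"))  // UTC -> given zone
df.select(to_utc_timestamp(col("ts"), "+01:00"))                 // given zone -> UTC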
@@ -5495,8 +5669,8 @@ object functions {
* Bucketize rows into one or more time windows given a timestamp specifying column. Window
* starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window
* [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in
- * the order of months are not supported. The following example takes the average stock price for
- * a one minute window every 10 seconds starting 5 seconds after the hour:
+ * the order of months are not supported. The following example takes the average stock price
+ * for a one minute window every 10 seconds starting 5 seconds after the hour:
*
* {{{
* val df = ... // schema => timestamp: TimestampType, stockId: StringType, price: DoubleType
@@ -5515,23 +5689,23 @@ object functions {
* For a streaming query, you may use the function `current_timestamp` to generate windows on
* processing time.
*
- * @param timeColumn The column or the expression to use as the timestamp for windowing by time.
- * The time column must be of TimestampType or TimestampNTZType.
- * @param windowDuration A string specifying the width of the window, e.g. `10 minutes`,
- * `1 second`. Check `org.apache.spark.unsafe.types.CalendarInterval` for
- * valid duration identifiers. Note that the duration is a fixed length of
- * time, and does not vary over time according to a calendar. For example,
- * `1 day` always means 86,400,000 milliseconds, not a calendar day.
- * @param slideDuration A string specifying the sliding interval of the window, e.g. `1 minute`.
- * A new window will be generated every `slideDuration`. Must be less than
- * or equal to the `windowDuration`. Check
- * `org.apache.spark.unsafe.types.CalendarInterval` for valid duration
- * identifiers. This duration is likewise absolute, and does not vary
- * according to a calendar.
- * @param startTime The offset with respect to 1970-01-01 00:00:00 UTC with which to start
- * window intervals. For example, in order to have hourly tumbling windows that
- * start 15 minutes past the hour, e.g. 12:15-13:15, 13:15-14:15... provide
- * `startTime` as `15 minutes`.
+ * @param timeColumn
+ * The column or the expression to use as the timestamp for windowing by time. The time column
+ * must be of TimestampType or TimestampNTZType.
+ * @param windowDuration
+ * A string specifying the width of the window, e.g. `10 minutes`, `1 second`. Check
+ * `org.apache.spark.unsafe.types.CalendarInterval` for valid duration identifiers. Note that
+ * the duration is a fixed length of time, and does not vary over time according to a
+ * calendar. For example, `1 day` always means 86,400,000 milliseconds, not a calendar day.
+ * @param slideDuration
+ * A string specifying the sliding interval of the window, e.g. `1 minute`. A new window will
+ * be generated every `slideDuration`. Must be less than or equal to the `windowDuration`.
+ * Check `org.apache.spark.unsafe.types.CalendarInterval` for valid duration identifiers. This
+ * duration is likewise absolute, and does not vary according to a calendar.
+ * @param startTime
+ * The offset with respect to 1970-01-01 00:00:00 UTC with which to start window intervals.
+ * For example, in order to have hourly tumbling windows that start 15 minutes past the hour,
+ * e.g. 12:15-13:15, 13:15-14:15... provide `startTime` as `15 minutes`.
*
* @group datetime_funcs
* @since 2.0.0
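The sliding-window example described above, written out against the schema shown in the `{{{ }}}` block (timestamp, stockId, price):

df.groupBy(window(col("timestamp"), "1 minute", "10 seconds", "5 seconds"), col("stockId"))
  .agg(avg("price"))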
@@ -5547,8 +5721,9 @@ object functions {
* Bucketize rows into one or more time windows given a timestamp specifying column. Window
* starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window
* [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in
- * the order of months are not supported. The windows start beginning at 1970-01-01 00:00:00 UTC.
- * The following example takes the average stock price for a one minute window every 10 seconds:
+ * the order of months are not supported. The windows start beginning at 1970-01-01 00:00:00
+ * UTC. The following example takes the average stock price for a one minute window every 10
+ * seconds:
*
* {{{
* val df = ... // schema => timestamp: TimestampType, stockId: StringType, price: DoubleType
@@ -5567,19 +5742,19 @@ object functions {
* For a streaming query, you may use the function `current_timestamp` to generate windows on
* processing time.
*
- * @param timeColumn The column or the expression to use as the timestamp for windowing by time.
- * The time column must be of TimestampType or TimestampNTZType.
- * @param windowDuration A string specifying the width of the window, e.g. `10 minutes`,
- * `1 second`. Check `org.apache.spark.unsafe.types.CalendarInterval` for
- * valid duration identifiers. Note that the duration is a fixed length of
- * time, and does not vary over time according to a calendar. For example,
- * `1 day` always means 86,400,000 milliseconds, not a calendar day.
- * @param slideDuration A string specifying the sliding interval of the window, e.g. `1 minute`.
- * A new window will be generated every `slideDuration`. Must be less than
- * or equal to the `windowDuration`. Check
- * `org.apache.spark.unsafe.types.CalendarInterval` for valid duration
- * identifiers. This duration is likewise absolute, and does not vary
- * according to a calendar.
+ * @param timeColumn
+ * The column or the expression to use as the timestamp for windowing by time. The time column
+ * must be of TimestampType or TimestampNTZType.
+ * @param windowDuration
+ * A string specifying the width of the window, e.g. `10 minutes`, `1 second`. Check
+ * `org.apache.spark.unsafe.types.CalendarInterval` for valid duration identifiers. Note that
+ * the duration is a fixed length of time, and does not vary over time according to a
+ * calendar. For example, `1 day` always means 86,400,000 milliseconds, not a calendar day.
+ * @param slideDuration
+ * A string specifying the sliding interval of the window, e.g. `1 minute`. A new window will
+ * be generated every `slideDuration`. Must be less than or equal to the `windowDuration`.
+ * Check `org.apache.spark.unsafe.types.CalendarInterval` for valid duration identifiers. This
+ * duration is likewise absolute, and does not vary according to a calendar.
*
* @group datetime_funcs
* @since 2.0.0
@@ -5589,11 +5764,11 @@ object functions {
}
/**
- * Generates tumbling time windows given a timestamp specifying column. Window
- * starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window
- * [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in
- * the order of months are not supported. The windows start beginning at 1970-01-01 00:00:00 UTC.
- * The following example takes the average stock price for a one minute tumbling window:
+ * Generates tumbling time windows given a timestamp specifying column. Window starts are
+ * inclusive but the window ends are exclusive, e.g. 12:05 will be in the window [12:05,12:10)
+ * but not in [12:00,12:05). Windows can support microsecond precision. Windows in the order of
+ * months are not supported. The windows start beginning at 1970-01-01 00:00:00 UTC. The
+ * following example takes the average stock price for a one minute tumbling window:
*
* {{{
* val df = ... // schema => timestamp: TimestampType, stockId: StringType, price: DoubleType
@@ -5612,11 +5787,12 @@ object functions {
* For a streaming query, you may use the function `current_timestamp` to generate windows on
* processing time.
*
- * @param timeColumn The column or the expression to use as the timestamp for windowing by time.
- * The time column must be of TimestampType or TimestampNTZType.
- * @param windowDuration A string specifying the width of the window, e.g. `10 minutes`,
- * `1 second`. Check `org.apache.spark.unsafe.types.CalendarInterval` for
- * valid duration identifiers.
+ * @param timeColumn
+ * The column or the expression to use as the timestamp for windowing by time. The time column
+ * must be of TimestampType or TimestampNTZType.
+ * @param windowDuration
+ * A string specifying the width of the window, e.g. `10 minutes`, `1 second`. Check
+ * `org.apache.spark.unsafe.types.CalendarInterval` for valid duration identifiers.
*
* @group datetime_funcs
* @since 2.0.0
@@ -5632,8 +5808,9 @@ object functions {
* inclusive and end is exclusive. Since event time can support microsecond precision,
* window_time(window) = window.end - 1 microsecond.
*
- * @param windowColumn The window column (typically produced by window aggregation) of type
- * StructType { start: Timestamp, end: Timestamp }
+ * @param windowColumn
+ * The window column (typically produced by window aggregation) of type StructType { start:
+ * Timestamp, end: Timestamp }
*
* @group datetime_funcs
* @since 3.4.0
@@ -5644,10 +5821,9 @@ object functions {
* Generates session window given a timestamp specifying column.
*
* Session window is one of dynamic windows, which means the length of window is varying
- * according to the given inputs. The length of session window is defined as "the timestamp
- * of latest input of the session + gap duration", so when the new inputs are bound to the
- * current session window, the end time of session window can be expanded according to the new
- * inputs.
+ * according to the given inputs. The length of session window is defined as "the timestamp of
+ * latest input of the session + gap duration", so when the new inputs are bound to the current
+ * session window, the end time of session window can be expanded according to the new inputs.
*
* Windows can support microsecond precision. gapDuration in the order of months are not
* supported.
@@ -5655,11 +5831,12 @@ object functions {
* For a streaming query, you may use the function `current_timestamp` to generate windows on
* processing time.
*
- * @param timeColumn The column or the expression to use as the timestamp for windowing by time.
- * The time column must be of TimestampType or TimestampNTZType.
- * @param gapDuration A string specifying the timeout of the session, e.g. `10 minutes`,
- * `1 second`. Check `org.apache.spark.unsafe.types.CalendarInterval` for
- * valid duration identifiers.
+ * @param timeColumn
+ * The column or the expression to use as the timestamp for windowing by time. The time column
+ * must be of TimestampType or TimestampNTZType.
+ * @param gapDuration
+ * A string specifying the timeout of the session, e.g. `10 minutes`, `1 second`. Check
+ * `org.apache.spark.unsafe.types.CalendarInterval` for valid duration identifiers.
*
* @group datetime_funcs
* @since 3.2.0
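A minimal sketch of a session-window aggregation; `userId` is an assumed grouping column:

df.groupBy(session_window(col("timestamp"), "10 minutes"), col("userId"))
  .count()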
@@ -5671,17 +5848,17 @@ object functions {
* Generates session window given a timestamp specifying column.
*
* Session window is one of dynamic windows, which means the length of window is varying
- * according to the given inputs. For static gap duration, the length of session window
- * is defined as "the timestamp of latest input of the session + gap duration", so when
- * the new inputs are bound to the current session window, the end time of session window
- * can be expanded according to the new inputs.
- *
- * Besides a static gap duration value, users can also provide an expression to specify
- * gap duration dynamically based on the input row. With dynamic gap duration, the closing
- * of a session window does not depend on the latest input anymore. A session window's range
- * is the union of all events' ranges which are determined by event start time and evaluated
- * gap duration during the query execution. Note that the rows with negative or zero gap
- * duration will be filtered out from the aggregation.
+ * according to the given inputs. For static gap duration, the length of session window is
+ * defined as "the timestamp of latest input of the session + gap duration", so when the new
+ * inputs are bound to the current session window, the end time of session window can be
+ * expanded according to the new inputs.
+ *
+ * Besides a static gap duration value, users can also provide an expression to specify gap
+ * duration dynamically based on the input row. With dynamic gap duration, the closing of a
+ * session window does not depend on the latest input anymore. A session window's range is the
+ * union of all events' ranges which are determined by event start time and evaluated gap
+ * duration during the query execution. Note that the rows with negative or zero gap duration
+ * will be filtered out from the aggregation.
*
* Windows can support microsecond precision. gapDuration in the order of months are not
* supported.
@@ -5689,11 +5866,13 @@ object functions {
* For a streaming query, you may use the function `current_timestamp` to generate windows on
* processing time.
*
- * @param timeColumn The column or the expression to use as the timestamp for windowing by time.
- * The time column must be of TimestampType or TimestampNTZType.
- * @param gapDuration A column specifying the timeout of the session. It could be static value,
- * e.g. `10 minutes`, `1 second`, or an expression/UDF that specifies gap
- * duration dynamically based on the input row.
+ * @param timeColumn
+ * The column or the expression to use as the timestamp for windowing by time. The time column
+ * must be of TimestampType or TimestampNTZType.
+ * @param gapDuration
+ * A column specifying the timeout of the session. It could be a static value, e.g. `10
+ * minutes`, `1 second`, or an expression/UDF that specifies gap duration dynamically based on
+ * the input row.
*
* @group datetime_funcs
* @since 3.2.0
@@ -5702,8 +5881,7 @@ object functions {
Column.fn("session_window", timeColumn, gapDuration)
/**
- * Converts the number of seconds from the Unix epoch (1970-01-01T00:00:00Z)
- * to a timestamp.
+ * Converts the number of seconds from the Unix epoch (1970-01-01T00:00:00Z) to a timestamp.
* @group datetime_funcs
* @since 3.1.0
*/
@@ -5726,8 +5904,8 @@ object functions {
def timestamp_micros(e: Column): Column = Column.fn("timestamp_micros", e)
/**
- * Gets the difference between the timestamps in the specified units by truncating
- * the fraction part.
+ * Gets the difference between the timestamps in the specified units by truncating the fraction
+ * part.
*
* @group datetime_funcs
* @since 4.0.0
@@ -5745,8 +5923,8 @@ object functions {
Column.internalFn("timestampadd", lit(unit), quantity, ts)
/**
- * Parses the `timestamp` expression with the `format` expression
- * to a timestamp without time zone. Returns null with invalid input.
+ * Parses the `timestamp` expression with the `format` expression to a timestamp without time
+ * zone. Returns null with invalid input.
*
* @group datetime_funcs
* @since 3.5.0
@@ -5765,8 +5943,8 @@ object functions {
Column.fn("to_timestamp_ltz", timestamp)
/**
- * Parses the `timestamp_str` expression with the `format` expression
- * to a timestamp without time zone. Returns null with invalid input.
+ * Parses the `timestamp_str` expression with the `format` expression to a timestamp without
+ * time zone. Returns null with invalid input.
*
* @group datetime_funcs
* @since 3.5.0
@@ -5855,9 +6033,12 @@ object functions {
* Returns an array containing all the elements in `x` from index `start` (or starting from the
* end if `start` is negative) with the specified `length`.
*
- * @param x the array column to be sliced
- * @param start the starting index
- * @param length the length of the slice
+ * @param x
+ * the array column to be sliced
+ * @param start
+ * the starting index
+ * @param length
+ * the length of the slice
*
* @group array_funcs
* @since 2.4.0
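A usage sketch for `slice`, assuming an array column `arr`:

df.select(slice(col("arr"), 2, 3))   // elements 2, 3 and 4 (start is 1-based)
df.select(slice(col("arr"), -2, 2))  // last two elements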
@@ -5869,9 +6050,12 @@ object functions {
* Returns an array containing all the elements in `x` from index `start` (or starting from the
* end if `start` is negative) with the specified `length`.
*
- * @param x the array column to be sliced
- * @param start the starting index
- * @param length the length of the slice
+ * @param x
+ * the array column to be sliced
+ * @param start
+ * the starting index
+ * @param length
+ * the length of the slice
*
* @group array_funcs
* @since 3.1.0
@@ -5897,10 +6081,11 @@ object functions {
Column.fn("array_join", column, lit(delimiter))
/**
- * Concatenates multiple input columns together into a single column.
- * The function works with strings, binary and compatible array columns.
+ * Concatenates multiple input columns together into a single column. The function works with
+ * strings, binary and compatible array columns.
*
- * @note Returns null if any of the input columns are null.
+ * @note
+ * Returns null if any of the input columns are null.
*
* @group collection_funcs
* @since 1.5.0
@@ -5909,11 +6094,12 @@ object functions {
def concat(exprs: Column*): Column = Column.fn("concat", exprs: _*)
/**
- * Locates the position of the first occurrence of the value in the given array as long.
- * Returns null if either of the arguments are null.
+ * Locates the position of the first occurrence of the value in the given array as a long.
+ * Returns null if either of the arguments is null.
*
- * @note The position is not zero based, but 1 based index. Returns 0 if value
- * could not be found in array.
+ * @note
+ * The position is not zero based, but 1 based index. Returns 0 if value could not be found in
+ * array.
*
* @group array_funcs
* @since 2.4.0
@@ -5922,8 +6108,8 @@ object functions {
Column.fn("array_position", column, lit(value))
/**
- * Returns element of array at given index in value if column is array. Returns value for
- * the given key in value if column is map.
+ * Returns element of array at given index in value if column is array. Returns value for the
+ * given key in value if column is map.
*
* @group collection_funcs
* @since 2.4.0
@@ -5945,8 +6131,8 @@ object functions {
Column.fn("try_element_at", column, value)
/**
- * Returns element of array at given (0-based) index. If the index points
- * outside of the array boundaries, then this function returns NULL.
+ * Returns element of array at given (0-based) index. If the index points outside of the array
+ * boundaries, then this function returns NULL.
*
* @group array_funcs
* @since 3.4.0
@@ -5955,8 +6141,8 @@ object functions {
/**
* Sorts the input array in ascending order. The elements of the input array must be orderable.
- * NaN is greater than any non-NaN elements for double/float type.
- * Null elements will be placed at the end of the returned array.
+ * NaN is greater than any non-NaN elements for double/float type. Null elements will be placed
+ * at the end of the returned array.
*
* @group collection_funcs
* @since 2.4.0
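A combined sketch of the array helpers documented above, assuming array columns `a1`, `a2` and `arr`:

df.select(concat(col("a1"), col("a2")))     // null if any input is null
df.select(array_position(col("arr"), "x"))  // 1-based index as a long; 0 if absent
df.select(element_at(col("arr"), 1))        // 1-based for arrays; key lookup for maps
df.select(array_sort(col("arr")))           // ascending, null elements last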
@@ -6010,8 +6196,8 @@ object functions {
def array_distinct(e: Column): Column = Column.fn("array_distinct", e)
/**
- * Returns an array of the elements in the intersection of the given two arrays,
- * without duplicates.
+ * Returns an array of the elements in the intersection of the given two arrays, without
+ * duplicates.
*
* @group array_funcs
* @since 2.4.0
@@ -6038,8 +6224,8 @@ object functions {
Column.fn("array_union", col1, col2)
/**
- * Returns an array of the elements in the first array but not in the second array,
- * without duplicates. The order of elements in the result is not determined
+ * Returns an array of the elements in the first array but not in the second array, without
+ * duplicates. The order of elements in the result is not determined.
*
* @group array_funcs
* @since 2.4.0
@@ -6047,7 +6233,6 @@ object functions {
def array_except(col1: Column, col2: Column): Column =
Column.fn("array_except", col1, col2)
-
private def createLambda(f: Column => Column) = {
val x = internal.UnresolvedNamedLambdaVariable("x")
val function = f(Column(x)).node
@@ -6070,14 +6255,16 @@ object functions {
}
/**
- * Returns an array of elements after applying a transformation to each element
- * in the input array.
+ * Returns an array of elements after applying a transformation to each element in the input
+ * array.
* {{{
* df.select(transform(col("i"), x => x + 1))
* }}}
*
- * @param column the input array column
- * @param f col => transformed_col, the lambda function to transform the input column
+ * @param column
+ * the input array column
+ * @param f
+ * col => transformed_col, the lambda function to transform the input column
*
* @group collection_funcs
* @since 3.0.0
@@ -6086,15 +6273,17 @@ object functions {
Column.fn("transform", column, createLambda(f))
/**
- * Returns an array of elements after applying a transformation to each element
- * in the input array.
+ * Returns an array of elements after applying a transformation to each element in the input
+ * array.
* {{{
* df.select(transform(col("i"), (x, i) => x + i))
* }}}
*
- * @param column the input array column
- * @param f (col, index) => transformed_col, the lambda function to transform the input
- * column given the index. Indices start at 0.
+ * @param column
+ * the input array column
+ * @param f
+ * (col, index) => transformed_col, the lambda function to transform the input column given
+ * the index. Indices start at 0.
*
* @group collection_funcs
* @since 3.0.0
@@ -6108,8 +6297,10 @@ object functions {
* df.select(exists(col("i"), _ % 2 === 0))
* }}}
*
- * @param column the input array column
- * @param f col => predicate, the Boolean predicate to check the input column
+ * @param column
+ * the input array column
+ * @param f
+ * col => predicate, the Boolean predicate to check the input column
*
* @group collection_funcs
* @since 3.0.0
@@ -6123,8 +6314,10 @@ object functions {
* df.select(forall(col("i"), x => x % 2 === 0))
* }}}
*
- * @param column the input array column
- * @param f col => predicate, the Boolean predicate to check the input column
+ * @param column
+ * the input array column
+ * @param f
+ * col => predicate, the Boolean predicate to check the input column
*
* @group collection_funcs
* @since 3.0.0
@@ -6138,8 +6331,10 @@ object functions {
* df.select(filter(col("s"), x => x % 2 === 0))
* }}}
*
- * @param column the input array column
- * @param f col => predicate, the Boolean predicate to filter the input column
+ * @param column
+ * the input array column
+ * @param f
+ * col => predicate, the Boolean predicate to filter the input column
*
* @group collection_funcs
* @since 3.0.0
@@ -6153,9 +6348,11 @@ object functions {
* df.select(filter(col("s"), (x, i) => i % 2 === 0))
* }}}
*
- * @param column the input array column
- * @param f (col, index) => predicate, the Boolean predicate to filter the input column
- * given the index. Indices start at 0.
+ * @param column
+ * the input array column
+ * @param f
+ * (col, index) => predicate, the Boolean predicate to filter the input column given the
+ * index. Indices start at 0.
*
* @group collection_funcs
* @since 3.0.0
@@ -6164,19 +6361,23 @@ object functions {
Column.fn("filter", column, createLambda(f))
/**
- * Applies a binary operator to an initial state and all elements in the array,
- * and reduces this to a single state. The final state is converted into the final result
- * by applying a finish function.
+ * Applies a binary operator to an initial state and all elements in the array, and reduces this
+ * to a single state. The final state is converted into the final result by applying a finish
+ * function.
* {{{
* df.select(aggregate(col("i"), lit(0), (acc, x) => acc + x, _ * 10))
* }}}
*
- * @param expr the input array column
- * @param initialValue the initial value
- * @param merge (combined_value, input_value) => combined_value, the merge function to merge
- * an input value to the combined_value
- * @param finish combined_value => final_value, the lambda function to convert the combined value
- * of all inputs to final result
+ * @param expr
+ * the input array column
+ * @param initialValue
+ * the initial value
+ * @param merge
+ * (combined_value, input_value) => combined_value, the merge function to merge an input value
+ * to the combined_value
+ * @param finish
+ * combined_value => final_value, the lambda function to convert the combined value of all
+ * inputs to final result
*
* @group collection_funcs
* @since 3.0.0
@@ -6189,16 +6390,19 @@ object functions {
Column.fn("aggregate", expr, initialValue, createLambda(merge), createLambda(finish))
/**
- * Applies a binary operator to an initial state and all elements in the array,
- * and reduces this to a single state.
+ * Applies a binary operator to an initial state and all elements in the array, and reduces this
+ * to a single state.
* {{{
* df.select(aggregate(col("i"), lit(0), (acc, x) => acc + x))
* }}}
*
- * @param expr the input array column
- * @param initialValue the initial value
- * @param merge (combined_value, input_value) => combined_value, the merge function to merge
- * an input value to the combined_value
+ * @param expr
+ * the input array column
+ * @param initialValue
+ * the initial value
+ * @param merge
+ * (combined_value, input_value) => combined_value, the merge function to merge an input value
+ * to the combined_value
* @group collection_funcs
* @since 3.0.0
*/
@@ -6206,19 +6410,23 @@ object functions {
aggregate(expr, initialValue, merge, c => c)
/**
- * Applies a binary operator to an initial state and all elements in the array,
- * and reduces this to a single state. The final state is converted into the final result
- * by applying a finish function.
+ * Applies a binary operator to an initial state and all elements in the array, and reduces this
+ * to a single state. The final state is converted into the final result by applying a finish
+ * function.
* {{{
* df.select(aggregate(col("i"), lit(0), (acc, x) => acc + x, _ * 10))
* }}}
*
- * @param expr the input array column
- * @param initialValue the initial value
- * @param merge (combined_value, input_value) => combined_value, the merge function to merge
- * an input value to the combined_value
- * @param finish combined_value => final_value, the lambda function to convert the combined value
- * of all inputs to final result
+ * @param expr
+ * the input array column
+ * @param initialValue
+ * the initial value
+ * @param merge
+ * (combined_value, input_value) => combined_value, the merge function to merge an input value
+ * to the combined_value
+ * @param finish
+ * combined_value => final_value, the lambda function to convert the combined value of all
+ * inputs to final result
*
* @group collection_funcs
* @since 3.5.0
@@ -6231,16 +6439,19 @@ object functions {
Column.fn("reduce", expr, initialValue, createLambda(merge), createLambda(finish))
/**
- * Applies a binary operator to an initial state and all elements in the array,
- * and reduces this to a single state.
+ * Applies a binary operator to an initial state and all elements in the array, and reduces this
+ * to a single state.
* {{{
* df.select(aggregate(col("i"), lit(0), (acc, x) => acc + x))
* }}}
*
- * @param expr the input array column
- * @param initialValue the initial value
- * @param merge (combined_value, input_value) => combined_value, the merge function to merge
- * an input value to the combined_value
+ * @param expr
+ * the input array column
+ * @param initialValue
+ * the initial value
+ * @param merge
+ * (combined_value, input_value) => combined_value, the merge function to merge an input value
+ * to the combined_value
* @group collection_funcs
* @since 3.5.0
*/
@@ -6248,16 +6459,19 @@ object functions {
reduce(expr, initialValue, merge, c => c)
/**
- * Merge two given arrays, element-wise, into a single array using a function.
- * If one array is shorter, nulls are appended at the end to match the length of the longer
- * array, before applying the function.
+ * Merge two given arrays, element-wise, into a single array using a function. If one array is
+ * shorter, nulls are appended at the end to match the length of the longer array, before
+ * applying the function.
* {{{
* df.select(zip_with(df1("val1"), df1("val2"), (x, y) => x + y))
* }}}
*
- * @param left the left input array column
- * @param right the right input array column
- * @param f (lCol, rCol) => col, the lambda function to merge two input columns into one column
+ * @param left
+ * the left input array column
+ * @param right
+ * the right input array column
+ * @param f
+ * (lCol, rCol) => col, the lambda function to merge two input columns into one column
*
* @group collection_funcs
* @since 3.0.0
@@ -6266,14 +6480,16 @@ object functions {
Column.fn("zip_with", left, right, createLambda(f))
/**
- * Applies a function to every key-value pair in a map and returns
- * a map with the results of those applications as the new keys for the pairs.
+ * Applies a function to every key-value pair in a map and returns a map with the results of
+ * those applications as the new keys for the pairs.
* {{{
* df.select(transform_keys(col("i"), (k, v) => k + v))
* }}}
*
- * @param expr the input map column
- * @param f (key, value) => new_key, the lambda function to transform the key of input map column
+ * @param expr
+ * the input map column
+ * @param f
+ * (key, value) => new_key, the lambda function to transform the key of input map column
*
* @group collection_funcs
* @since 3.0.0
@@ -6282,15 +6498,16 @@ object functions {
Column.fn("transform_keys", expr, createLambda(f))
/**
- * Applies a function to every key-value pair in a map and returns
- * a map with the results of those applications as the new values for the pairs.
+ * Applies a function to every key-value pair in a map and returns a map with the results of
+ * those applications as the new values for the pairs.
* {{{
* df.select(transform_values(col("i"), (k, v) => k + v))
* }}}
*
- * @param expr the input map column
- * @param f (key, value) => new_value, the lambda function to transform the value of input map
- * column
+ * @param expr
+ * the input map column
+ * @param f
+ * (key, value) => new_value, the lambda function to transform the value of input map column
*
* @group collection_funcs
* @since 3.0.0
@@ -6304,8 +6521,10 @@ object functions {
* df.select(map_filter(col("m"), (k, v) => k * 10 === v))
* }}}
*
- * @param expr the input map column
- * @param f (key, value) => predicate, the Boolean predicate to filter the input map column
+ * @param expr
+ * the input map column
+ * @param f
+ * (key, value) => predicate, the Boolean predicate to filter the input map column
*
* @group collection_funcs
* @since 3.0.0
@@ -6319,9 +6538,12 @@ object functions {
* df.select(map_zip_with(df("m1"), df("m2"), (k, v1, v2) => k === v1 + v2))
* }}}
*
- * @param left the left input map column
- * @param right the right input map column
- * @param f (key, value1, value2) => new_value, the lambda function to merge the map values
+ * @param left
+ * the left input map column
+ * @param right
+ * the right input map column
+ * @param f
+ * (key, value1, value2) => new_value, the lambda function to merge the map values
*
* @group collection_funcs
* @since 3.0.0
@@ -6330,9 +6552,9 @@ object functions {
Column.fn("map_zip_with", left, right, createLambda(f))
/**
- * Creates a new row for each element in the given array or map column.
- * Uses the default column name `col` for elements in the array and
- * `key` and `value` for elements in the map unless specified otherwise.
+ * Creates a new row for each element in the given array or map column. Uses the default column
+ * name `col` for elements in the array and `key` and `value` for elements in the map unless
+ * specified otherwise.
*
* @group generator_funcs
* @since 1.3.0
@@ -6340,10 +6562,9 @@ object functions {
def explode(e: Column): Column = Column.fn("explode", e)
/**
- * Creates a new row for each element in the given array or map column.
- * Uses the default column name `col` for elements in the array and
- * `key` and `value` for elements in the map unless specified otherwise.
- * Unlike explode, if the array/map is null or empty then null is produced.
+ * Creates a new row for each element in the given array or map column. Uses the default column
+ * name `col` for elements in the array and `key` and `value` for elements in the map unless
+ * specified otherwise. Unlike explode, if the array/map is null or empty then null is produced.
*
* @group generator_funcs
* @since 2.2.0
@@ -6351,9 +6572,9 @@ object functions {
def explode_outer(e: Column): Column = Column.fn("explode_outer", e)
/**
- * Creates a new row for each element with position in the given array or map column.
- * Uses the default column name `pos` for position, and `col` for elements in the array
- * and `key` and `value` for elements in the map unless specified otherwise.
+ * Creates a new row for each element with position in the given array or map column. Uses the
+ * default column name `pos` for position, and `col` for elements in the array and `key` and
+ * `value` for elements in the map unless specified otherwise.
*
* @group generator_funcs
* @since 2.1.0
@@ -6361,17 +6582,17 @@ object functions {
def posexplode(e: Column): Column = Column.fn("posexplode", e)
/**
- * Creates a new row for each element with position in the given array or map column.
- * Uses the default column name `pos` for position, and `col` for elements in the array
- * and `key` and `value` for elements in the map unless specified otherwise.
- * Unlike posexplode, if the array/map is null or empty then the row (null, null) is produced.
+ * Creates a new row for each element with position in the given array or map column. Uses the
+ * default column name `pos` for position, and `col` for elements in the array and `key` and
+ * `value` for elements in the map unless specified otherwise. Unlike posexplode, if the
+ * array/map is null or empty then the row (null, null) is produced.
*
* @group generator_funcs
* @since 2.2.0
*/
def posexplode_outer(e: Column): Column = Column.fn("posexplode_outer", e)
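A minimal sketch contrasting `explode` and `posexplode` (hypothetical array column `arr`):

{{{
  val df = spark.sql("SELECT array('x', 'y') AS arr")
  df.select(explode(col("arr"))).show()      // two rows in a column named `col`
  df.select(posexplode(col("arr"))).show()   // adds a `pos` column with 0-based positions
}}}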
- /**
+ /**
* Creates a new row for each element in the given array of structs.
*
* @group generator_funcs
@@ -6380,8 +6601,8 @@ object functions {
def inline(e: Column): Column = Column.fn("inline", e)
/**
- * Creates a new row for each element in the given array of structs.
- * Unlike inline, if the array is null or empty then null is produced for each nested column.
+ * Creates a new row for each element in the given array of structs. Unlike inline, if the array
+ * is null or empty then null is produced for each nested column.
*
* @group generator_funcs
* @since 3.4.0
@@ -6415,14 +6636,15 @@ object functions {
* (Scala-specific) Parses a column containing a JSON string into a `StructType` with the
* specified schema. Returns `null`, in the case of an unparseable string.
*
- * @param e a string column containing JSON data.
- * @param schema the schema to use when parsing the json string
- * @param options options to control how the json is parsed. Accepts the same options as the
- * json data source.
- * See
- *
- * Data Source Option in the version you use.
+ * @param e
+ * a string column containing JSON data.
+ * @param schema
+ * the schema to use when parsing the json string
+ * @param options
+ * options to control how the json is parsed. Accepts the same options as the json data
+ * source. See Data
+ * Source Option in the version you use.
*
* @group json_funcs
* @since 2.1.0
@@ -6434,17 +6656,18 @@ object functions {
// scalastyle:off line.size.limit
/**
* (Scala-specific) Parses a column containing a JSON string into a `MapType` with `StringType`
- * as keys type, `StructType` or `ArrayType` with the specified schema.
- * Returns `null`, in the case of an unparseable string.
- *
- * @param e a string column containing JSON data.
- * @param schema the schema to use when parsing the json string
- * @param options options to control how the json is parsed. accepts the same options and the
- * json data source.
- * See
- *
- * Data Source Option in the version you use.
+ * as keys type, `StructType` or `ArrayType` with the specified schema. Returns `null`, in the
+ * case of an unparseable string.
+ *
+ * @param e
+ * a string column containing JSON data.
+ * @param schema
+ * the schema to use when parsing the json string
+ * @param options
+ * options to control how the json is parsed. Accepts the same options as the json data
+ * source. See Data
+ * Source Option in the version you use.
*
* @group json_funcs
* @since 2.2.0
@@ -6459,14 +6682,15 @@ object functions {
* (Java-specific) Parses a column containing a JSON string into a `StructType` with the
* specified schema. Returns `null`, in the case of an unparseable string.
*
- * @param e a string column containing JSON data.
- * @param schema the schema to use when parsing the json string
- * @param options options to control how the json is parsed. accepts the same options and the
- * json data source.
- * See
- *
- * Data Source Option in the version you use.
+ * @param e
+ * a string column containing JSON data.
+ * @param schema
+ * the schema to use when parsing the json string
+ * @param options
+ * options to control how the json is parsed. Accepts the same options as the json data
+ * source. See Data
+ * Source Option in the version you use.
*
* @group json_funcs
* @since 2.1.0
@@ -6478,17 +6702,18 @@ object functions {
// scalastyle:off line.size.limit
/**
* (Java-specific) Parses a column containing a JSON string into a `MapType` with `StringType`
- * as keys type, `StructType` or `ArrayType` with the specified schema.
- * Returns `null`, in the case of an unparseable string.
- *
- * @param e a string column containing JSON data.
- * @param schema the schema to use when parsing the json string
- * @param options options to control how the json is parsed. accepts the same options and the
- * json data source.
- * See
- *
- * Data Source Option in the version you use.
+ * as keys type, `StructType` or `ArrayType` with the specified schema. Returns `null`, in the
+ * case of an unparseable string.
+ *
+ * @param e
+ * a string column containing JSON data.
+ * @param schema
+ * the schema to use when parsing the json string
+ * @param options
+ * options to control how the json is parsed. Accepts the same options as the json data
+ * source. See Data
+ * Source Option in the version you use.
*
* @group json_funcs
* @since 2.2.0
@@ -6502,8 +6727,10 @@ object functions {
* Parses a column containing a JSON string into a `StructType` with the specified schema.
* Returns `null`, in the case of an unparseable string.
*
- * @param e a string column containing JSON data.
- * @param schema the schema to use when parsing the json string
+ * @param e
+ * a string column containing JSON data.
+ * @param schema
+ * the schema to use when parsing the json string
*
* @group json_funcs
* @since 2.1.0
@@ -6513,11 +6740,13 @@ object functions {
/**
* Parses a column containing a JSON string into a `MapType` with `StringType` as keys type,
- * `StructType` or `ArrayType` with the specified schema.
- * Returns `null`, in the case of an unparseable string.
+ * `StructType` or `ArrayType` with the specified schema. Returns `null`, in the case of an
+ * unparseable string.
*
- * @param e a string column containing JSON data.
- * @param schema the schema to use when parsing the json string
+ * @param e
+ * a string column containing JSON data.
+ * @param schema
+ * the schema to use when parsing the json string
*
* @group json_funcs
* @since 2.2.0
@@ -6528,17 +6757,18 @@ object functions {
// scalastyle:off line.size.limit
/**
* (Java-specific) Parses a column containing a JSON string into a `MapType` with `StringType`
- * as keys type, `StructType` or `ArrayType` with the specified schema.
- * Returns `null`, in the case of an unparseable string.
- *
- * @param e a string column containing JSON data.
- * @param schema the schema as a DDL-formatted string.
- * @param options options to control how the json is parsed. accepts the same options and the
- * json data source.
- * See
- *
- * Data Source Option in the version you use.
+ * as keys type, `StructType` or `ArrayType` with the specified schema. Returns `null`, in the
+ * case of an unparseable string.
+ *
+ * @param e
+ * a string column containing JSON data.
+ * @param schema
+ * the schema as a DDL-formatted string.
+ * @param options
+ * options to control how the json is parsed. Accepts the same options as the json data
+ * source. See Data
+ * Source Option in the version you use.
*
* @group json_funcs
* @since 2.1.0
@@ -6551,17 +6781,18 @@ object functions {
// scalastyle:off line.size.limit
/**
* (Scala-specific) Parses a column containing a JSON string into a `MapType` with `StringType`
- * as keys type, `StructType` or `ArrayType` with the specified schema.
- * Returns `null`, in the case of an unparseable string.
- *
- * @param e a string column containing JSON data.
- * @param schema the schema as a DDL-formatted string.
- * @param options options to control how the json is parsed. accepts the same options and the
- * json data source.
- * See
- *
- * Data Source Option in the version you use.
+ * as keys type, `StructType` or `ArrayType` with the specified schema. Returns `null`, in the
+ * case of an unparseable string.
+ *
+ * @param e
+ * a string column containing JSON data.
+ * @param schema
+ * the schema as a DDL-formatted string.
+ * @param options
+ * options to control how the json is parsed. Accepts the same options as the json data
+ * source. See Data
+ * Source Option in the version you use.
*
* @group json_funcs
* @since 2.3.0
@@ -6573,11 +6804,13 @@ object functions {
/**
* (Scala-specific) Parses a column containing a JSON string into a `MapType` with `StringType`
- * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema.
- * Returns `null`, in the case of an unparseable string.
+ * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. Returns
+ * `null`, in the case of an unparseable string.
*
- * @param e a string column containing JSON data.
- * @param schema the schema to use when parsing the json string
+ * @param e
+ * a string column containing JSON data.
+ * @param schema
+ * the schema to use when parsing the json string
*
* @group json_funcs
* @since 2.4.0
@@ -6589,17 +6822,18 @@ object functions {
// scalastyle:off line.size.limit
/**
* (Java-specific) Parses a column containing a JSON string into a `MapType` with `StringType`
- * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema.
- * Returns `null`, in the case of an unparseable string.
- *
- * @param e a string column containing JSON data.
- * @param schema the schema to use when parsing the json string
- * @param options options to control how the json is parsed. accepts the same options and the
- * json data source.
- * See
- *
- * Data Source Option in the version you use.
+ * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. Returns
+ * `null`, in the case of an unparseable string.
+ *
+ * @param e
+ * a string column containing JSON data.
+ * @param schema
+ * the schema to use when parsing the json string
+ * @param options
+ * options to control how the json is parsed. Accepts the same options as the json data
+ * source. See Data
+ * Source Option in the version you use.
*
* @group json_funcs
* @since 2.4.0
@@ -6620,7 +6854,8 @@ object functions {
* Parses a JSON string and constructs a Variant value. Returns null if the input string is not
* a valid JSON value.
*
- * @param json a string column that contains JSON data.
+ * @param json
+ * a string column that contains JSON data.
*
* @group variant_funcs
* @since 4.0.0
@@ -6637,6 +6872,18 @@ object functions {
*/
def parse_json(json: Column): Column = Column.fn("parse_json", json)
+ /**
+ * Converts a column containing nested inputs (array/map/struct) into a variant where maps and
+ * structs are converted to variant objects, which are unordered unlike SQL structs. Input maps
+ * can only have string keys.
+ *
+ * @param col
+ * a column with a nested schema or column name.
+ * @group variant_funcs
+ * @since 4.0.0
+ */
+ def to_variant_object(col: Column): Column = Column.fn("to_variant_object", col)
+
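A minimal sketch of the variant helpers documented here; per the docs above these are 4.0.0 APIs (the JSON literal and struct are hypothetical):

{{{
  val df = spark.sql("""SELECT '{"a": 1, "b": [2, 3]}' AS js""")
  df.select(parse_json(col("js"))).show(truncate = false)
  // nested array/map/struct inputs go through to_variant_object instead
  df.select(to_variant_object(struct(lit(1).as("a")))).show(truncate = false)
}}}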
/**
* Check if a variant value is a variant null. Returns true if and only if the input is a
* variant null and false otherwise (including in the case of SQL NULL).
@@ -6705,7 +6952,8 @@ object functions {
/**
* Parses a JSON string and infers its schema in DDL format.
*
- * @param json a JSON string.
+ * @param json
+ * a JSON string.
*
* @group json_funcs
* @since 2.4.0
@@ -6715,7 +6963,8 @@ object functions {
/**
* Parses a JSON string and infers its schema in DDL format.
*
- * @param json a foldable string column containing a JSON string.
+ * @param json
+ * a foldable string column containing a JSON string.
*
* @group json_funcs
* @since 2.4.0
@@ -6726,14 +6975,15 @@ object functions {
/**
* Parses a JSON string and infers its schema in DDL format using options.
*
- * @param json a foldable string column containing JSON data.
- * @param options options to control how the json is parsed. accepts the same options and the
- * json data source.
- * See
- *
- * Data Source Option in the version you use.
- * @return a column with string literal containing schema in DDL format.
+ * @param json
+ * a foldable string column containing JSON data.
+ * @param options
+ * options to control how the json is parsed. Accepts the same options as the json data
+ * source. See Data
+ * Source Option in the version you use.
+ * @return
+ * a column with string literal containing schema in DDL format.
*
* @group json_funcs
* @since 3.0.0
@@ -6743,8 +6993,8 @@ object functions {
Column.fnWithOptions("schema_of_json", options.asScala.iterator, json)
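A minimal sketch of `schema_of_json` on a foldable literal; the overload documented here additionally accepts the JSON data source options as a `java.util.Map`:

{{{
  spark.range(1)
    .select(schema_of_json(lit("""{"a": 1, "b": [1.5]}""")))
    .show(truncate = false)   // e.g. STRUCT<a: BIGINT, b: ARRAY<DOUBLE>>
}}}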
/**
- * Returns the number of elements in the outermost JSON array. `NULL` is returned in case of
- * any other valid JSON string, `NULL` or an invalid JSON.
+ * Returns the number of elements in the outermost JSON array. `NULL` is returned in case of any
+ * other valid JSON string, `NULL` or an invalid JSON.
*
* @group json_funcs
* @since 3.5.0
@@ -6753,8 +7003,8 @@ object functions {
/**
* Returns all the keys of the outermost JSON object as an array. If a valid JSON object is
- * given, all the keys of the outermost object will be returned as an array. If it is any
- * other valid JSON string, an invalid JSON string or an empty string, the function returns null.
+ * given, all the keys of the outermost object will be returned as an array. If it is any other
+ * valid JSON string, an invalid JSON string or an empty string, the function returns null.
*
* @group json_funcs
* @since 3.5.0
@@ -6763,19 +7013,18 @@ object functions {
// scalastyle:off line.size.limit
/**
- * (Scala-specific) Converts a column containing a `StructType`, `ArrayType` or
- * a `MapType` into a JSON string with the specified schema.
- * Throws an exception, in the case of an unsupported type.
+ * (Scala-specific) Converts a column containing a `StructType`, `ArrayType` or a `MapType` into
+ * a JSON string with the specified schema. Throws an exception, in the case of an unsupported
+ * type.
*
- * @param e a column containing a struct, an array or a map.
- * @param options options to control how the struct column is converted into a json string.
- * accepts the same options and the json data source.
- * See
- *
- * Data Source Option in the version you use.
- * Additionally the function supports the `pretty` option which enables
- * pretty JSON generation.
+ * @param e
+ * a column containing a struct, an array or a map.
+ * @param options
+ * options to control how the struct column is converted into a json string. Accepts the same
+ * options as the json data source. See Data
+ * Source Option in the version you use. Additionally the function supports the `pretty`
+ * option which enables pretty JSON generation.
*
* @group json_funcs
* @since 2.1.0
@@ -6786,19 +7035,18 @@ object functions {
// scalastyle:off line.size.limit
/**
- * (Java-specific) Converts a column containing a `StructType`, `ArrayType` or
- * a `MapType` into a JSON string with the specified schema.
- * Throws an exception, in the case of an unsupported type.
+ * (Java-specific) Converts a column containing a `StructType`, `ArrayType` or a `MapType` into
+ * a JSON string with the specified schema. Throws an exception, in the case of an unsupported
+ * type.
*
- * @param e a column containing a struct, an array or a map.
- * @param options options to control how the struct column is converted into a json string.
- * accepts the same options and the json data source.
- * See
- *
- * Data Source Option in the version you use.
- * Additionally the function supports the `pretty` option which enables
- * pretty JSON generation.
+ * @param e
+ * a column containing a struct, an array or a map.
+ * @param options
+ * options to control how the struct column is converted into a json string. Accepts the same
+ * options as the json data source. See Data
+ * Source Option in the version you use. Additionally the function supports the `pretty`
+ * option which enables pretty JSON generation.
*
* @group json_funcs
* @since 2.1.0
@@ -6808,11 +7056,11 @@ object functions {
to_json(e, options.asScala.toMap)
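A minimal sketch of `to_json`, with and without the `pretty` option (hypothetical struct column `s`):

{{{
  val df = spark.sql("SELECT named_struct('a', 1, 'b', 'x') AS s")
  df.select(to_json(col("s"))).show(false)                           // {"a":1,"b":"x"}
  df.select(to_json(col("s"), Map("pretty" -> "true"))).show(false)
}}}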
/**
- * Converts a column containing a `StructType`, `ArrayType` or
- * a `MapType` into a JSON string with the specified schema.
- * Throws an exception, in the case of an unsupported type.
+ * Converts a column containing a `StructType`, `ArrayType` or a `MapType` into a JSON string
+ * with the specified schema. Throws an exception, in the case of an unsupported type.
*
- * @param e a column containing a struct, an array or a map.
+ * @param e
+ * a column containing a struct, an array or a map.
*
* @group json_funcs
* @since 2.1.0
@@ -6822,10 +7070,11 @@ object functions {
/**
* Masks the given string value. The function replaces characters with 'X' or 'x', and numbers
- * with 'n'.
- * This can be useful for creating copies of tables with sensitive information removed.
+ * with 'n'. This can be useful for creating copies of tables with sensitive information
+ * removed.
*
- * @param input string value to mask. Supported types: STRING, VARCHAR, CHAR
+ * @param input
+ * string value to mask. Supported types: STRING, VARCHAR, CHAR
*
* @group string_funcs
* @since 3.5.0
@@ -6834,8 +7083,8 @@ object functions {
/**
* Masks the given string value. The function replaces upper-case characters with specific
- * character, lower-case characters with 'x', and numbers with 'n'.
- * This can be useful for creating copies of tables with sensitive information removed.
+ * character, lower-case characters with 'x', and numbers with 'n'. This can be useful for
+ * creating copies of tables with sensitive information removed.
*
* @param input
* string value to mask. Supported types: STRING, VARCHAR, CHAR
@@ -6850,8 +7099,8 @@ object functions {
/**
* Masks the given string value. The function replaces upper-case and lower-case characters with
- * the characters specified respectively, and numbers with 'n'.
- * This can be useful for creating copies of tables with sensitive information removed.
+ * the characters specified respectively, and numbers with 'n'. This can be useful for creating
+ * copies of tables with sensitive information removed.
*
* @param input
* string value to mask. Supported types: STRING, VARCHAR, CHAR
@@ -6868,8 +7117,8 @@ object functions {
/**
* Masks the given string value. The function replaces upper-case, lower-case characters and
- * numbers with the characters specified respectively.
- * This can be useful for creating copies of tables with sensitive information removed.
+ * numbers with the characters specified respectively. This can be useful for creating copies of
+ * tables with sensitive information removed.
*
* @param input
* string value to mask. Supported types: STRING, VARCHAR, CHAR
@@ -6916,8 +7165,8 @@ object functions {
* Returns length of array or map.
*
* This function returns -1 for null input only if spark.sql.ansi.enabled is false and
- * spark.sql.legacy.sizeOfNull is true. Otherwise, it returns null for null input.
- * With the default settings, the function returns null for null input.
+ * spark.sql.legacy.sizeOfNull is true. Otherwise, it returns null for null input. With the
+ * default settings, the function returns null for null input.
*
* @group collection_funcs
* @since 1.5.0
@@ -6928,8 +7177,8 @@ object functions {
* Returns length of array or map. This is an alias of `size` function.
*
* This function returns -1 for null input only if spark.sql.ansi.enabled is false and
- * spark.sql.legacy.sizeOfNull is true. Otherwise, it returns null for null input.
- * With the default settings, the function returns null for null input.
+ * spark.sql.legacy.sizeOfNull is true. Otherwise, it returns null for null input. With the
+ * default settings, the function returns null for null input.
*
* @group collection_funcs
* @since 3.5.0
@@ -6937,9 +7186,9 @@ object functions {
def cardinality(e: Column): Column = Column.fn("cardinality", e)
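A minimal sketch of `size` and its alias `cardinality` (hypothetical array and map columns):

{{{
  val df = spark.sql("SELECT array(1, 2, 3) AS a, map('k', 'v') AS m")
  df.select(size(col("a")), cardinality(col("m"))).show()   // 3, 1
}}}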
/**
- * Sorts the input array for the given column in ascending order,
- * according to the natural ordering of the array elements.
- * Null elements will be placed at the beginning of the returned array.
+ * Sorts the input array for the given column in ascending order, according to the natural
+ * ordering of the array elements. Null elements will be placed at the beginning of the returned
+ * array.
*
* @group array_funcs
* @since 1.5.0
@@ -6947,11 +7196,10 @@ object functions {
def sort_array(e: Column): Column = sort_array(e, asc = true)
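A minimal sketch showing null placement in both sort orders (hypothetical column `a`):

{{{
  val df = spark.sql("SELECT array(3, 1, null, 2) AS a")
  df.select(sort_array(col("a"))).show()                // [null, 1, 2, 3]
  df.select(sort_array(col("a"), asc = false)).show()   // [3, 2, 1, null]
}}}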
/**
- * Sorts the input array for the given column in ascending or descending order,
- * according to the natural ordering of the array elements. NaN is greater than any non-NaN
- * elements for double/float type. Null elements will be placed at the beginning of the returned
- * array in ascending order or
- * at the end of the returned array in descending order.
+ * Sorts the input array for the given column in ascending or descending order, according to the
+ * natural ordering of the array elements. NaN is greater than any non-NaN elements for
+ * double/float type. Null elements will be placed at the beginning of the returned array in
+ * ascending order or at the end of the returned array in descending order.
*
* @group array_funcs
* @since 1.5.0
@@ -6987,8 +7235,9 @@ object functions {
/**
* Aggregate function: returns a list of objects with duplicates.
*
- * @note The function is non-deterministic because the order of collected results depends
- * on the order of the rows which may be non-deterministic after a shuffle.
+ * @note
+ * The function is non-deterministic because the order of collected results depends on the
+ * order of the rows which may be non-deterministic after a shuffle.
* @group agg_funcs
* @since 3.5.0
*/
@@ -6997,7 +7246,8 @@ object functions {
/**
* Returns a random permutation of the given array.
*
- * @note The function is non-deterministic.
+ * @note
+ * The function is non-deterministic.
*
* @group array_funcs
* @since 2.4.0
@@ -7012,8 +7262,8 @@ object functions {
def reverse(e: Column): Column = Column.fn("reverse", e)
/**
- * Creates a single array from an array of arrays. If a structure of nested arrays is deeper than
- * two levels, only one level of nesting is removed.
+ * Creates a single array from an array of arrays. If a structure of nested arrays is deeper
+ * than two levels, only one level of nesting is removed.
* @group array_funcs
* @since 2.4.0
*/
@@ -7029,8 +7279,8 @@ object functions {
Column.fn("sequence", start, stop, step)
/**
- * Generate a sequence of integers from start to stop,
- * incrementing by 1 if start is less than or equal to stop, otherwise -1.
+ * Generate a sequence of integers from start to stop, incrementing by 1 if start is less than
+ * or equal to stop, otherwise -1.
*
* @group array_funcs
* @since 2.4.0
@@ -7038,8 +7288,8 @@ object functions {
def sequence(start: Column, stop: Column): Column = Column.fn("sequence", start, stop)
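A minimal sketch of the two-argument `sequence`, showing the implied step of 1 or -1:

{{{
  spark.range(1).select(sequence(lit(1), lit(5)), sequence(lit(5), lit(1))).show()
  // [1, 2, 3, 4, 5]   [5, 4, 3, 2, 1]
}}}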
/**
- * Creates an array containing the left argument repeated the number of times given by the
- * right argument.
+ * Creates an array containing the left argument repeated the number of times given by the right
+ * argument.
*
* @group array_funcs
* @since 2.4.0
@@ -7047,8 +7297,8 @@ object functions {
def array_repeat(left: Column, right: Column): Column = Column.fn("array_repeat", left, right)
/**
- * Creates an array containing the left argument repeated the number of times given by the
- * right argument.
+ * Creates an array containing the left argument repeated the number of times given by the right
+ * argument.
*
* @group array_funcs
* @since 2.4.0
@@ -7113,14 +7363,15 @@ object functions {
* Parses a column containing a CSV string into a `StructType` with the specified schema.
* Returns `null`, in the case of an unparseable string.
*
- * @param e a string column containing CSV data.
- * @param schema the schema to use when parsing the CSV string
- * @param options options to control how the CSV is parsed. accepts the same options and the
- * CSV data source.
- * See
- *
- * Data Source Option in the version you use.
+ * @param e
+ * a string column containing CSV data.
+ * @param schema
+ * the schema to use when parsing the CSV string
+ * @param options
+ * options to control how the CSV is parsed. Accepts the same options as the CSV data source.
+ * See Data
+ * Source Option in the version you use.
*
* @group csv_funcs
* @since 3.0.0
@@ -7131,17 +7382,18 @@ object functions {
// scalastyle:off line.size.limit
/**
- * (Java-specific) Parses a column containing a CSV string into a `StructType`
- * with the specified schema. Returns `null`, in the case of an unparseable string.
+ * (Java-specific) Parses a column containing a CSV string into a `StructType` with the
+ * specified schema. Returns `null`, in the case of an unparseable string.
*
- * @param e a string column containing CSV data.
- * @param schema the schema to use when parsing the CSV string
- * @param options options to control how the CSV is parsed. accepts the same options and the
- * CSV data source.
- * See
- *
- * Data Source Option in the version you use.
+ * @param e
+ * a string column containing CSV data.
+ * @param schema
+ * the schema to use when parsing the CSV string
+ * @param options
+ * options to control how the CSV is parsed. Accepts the same options as the CSV data source.
+ * See Data
+ * Source Option in the version you use.
*
* @group csv_funcs
* @since 3.0.0
@@ -7156,7 +7408,8 @@ object functions {
/**
* Parses a CSV string and infers its schema in DDL format.
*
- * @param csv a CSV string.
+ * @param csv
+ * a CSV string.
*
* @group csv_funcs
* @since 3.0.0
@@ -7166,7 +7419,8 @@ object functions {
/**
* Parses a CSV string and infers its schema in DDL format.
*
- * @param csv a foldable string column containing a CSV string.
+ * @param csv
+ * a foldable string column containing a CSV string.
*
* @group csv_funcs
* @since 3.0.0
@@ -7177,14 +7431,15 @@ object functions {
/**
* Parses a CSV string and infers its schema in DDL format using options.
*
- * @param csv a foldable string column containing a CSV string.
- * @param options options to control how the CSV is parsed. accepts the same options and the
- * CSV data source.
- * See
- *
- * Data Source Option in the version you use.
- * @return a column with string literal containing schema in DDL format.
+ * @param csv
+ * a foldable string column containing a CSV string.
+ * @param options
+ * options to control how the CSV is parsed. Accepts the same options as the CSV data source.
+ * See Data
+ * Source Option in the version you use.
+ * @return
+ * a column with string literal containing schema in DDL format.
*
* @group csv_funcs
* @since 3.0.0
@@ -7195,16 +7450,16 @@ object functions {
// scalastyle:off line.size.limit
/**
- * (Java-specific) Converts a column containing a `StructType` into a CSV string with
- * the specified schema. Throws an exception, in the case of an unsupported type.
+ * (Java-specific) Converts a column containing a `StructType` into a CSV string with the
+ * specified schema. Throws an exception, in the case of an unsupported type.
*
- * @param e a column containing a struct.
- * @param options options to control how the struct column is converted into a CSV string.
- * It accepts the same options and the CSV data source.
- * See
- *
- * Data Source Option in the version you use.
+ * @param e
+ * a column containing a struct.
+ * @param options
+ * options to control how the struct column is converted into a CSV string. It accepts the
+ * same options as the CSV data source. See Data
+ * Source Option in the version you use.
*
* @group csv_funcs
* @since 3.0.0
@@ -7217,7 +7472,8 @@ object functions {
* Converts a column containing a `StructType` into a CSV string with the specified schema.
* Throws an exception, in the case of an unsupported type.
*
- * @param e a column containing a struct.
+ * @param e
+ * a column containing a struct.
*
* @group csv_funcs
* @since 3.0.0
@@ -7226,17 +7482,18 @@ object functions {
// scalastyle:off line.size.limit
/**
- * Parses a column containing a XML string into the data type corresponding to the specified schema.
- * Returns `null`, in the case of an unparseable string.
- *
- * @param e a string column containing XML data.
- * @param schema the schema to use when parsing the XML string
- * @param options options to control how the XML is parsed. accepts the same options and the
- * XML data source.
- * See
- *
- * Data Source Option in the version you use.
+ * Parses a column containing an XML string into the data type corresponding to the specified
+ * schema. Returns `null`, in the case of an unparseable string.
+ *
+ * @param e
+ * a string column containing XML data.
+ * @param schema
+ * the schema to use when parsing the XML string
+ * @param options
+ * options to control how the XML is parsed. Accepts the same options as the XML data source.
+ * See Data
+ * Source Option in the version you use.
* @group xml_funcs
* @since 4.0.0
*/
@@ -7246,18 +7503,18 @@ object functions {
// scalastyle:off line.size.limit
/**
- * (Java-specific) Parses a column containing a XML string into a `StructType`
- * with the specified schema.
- * Returns `null`, in the case of an unparseable string.
+ * (Java-specific) Parses a column containing an XML string into a `StructType` with the
+ * specified schema. Returns `null`, in the case of an unparseable string.
*
- * @param e a string column containing XML data.
- * @param schema the schema as a DDL-formatted string.
- * @param options options to control how the XML is parsed. accepts the same options and the
- * xml data source.
- * See
- *
- * Data Source Option in the version you use.
+ * @param e
+ * a string column containing XML data.
+ * @param schema
+ * the schema as a DDL-formatted string.
+ * @param options
+ * options to control how the XML is parsed. Accepts the same options as the XML data source.
+ * See Data
+ * Source Option in the version you use.
* @group xml_funcs
* @since 4.0.0
*/
@@ -7268,11 +7525,13 @@ object functions {
// scalastyle:off line.size.limit
/**
- * (Java-specific) Parses a column containing a XML string into a `StructType`
- * with the specified schema. Returns `null`, in the case of an unparseable string.
+ * (Java-specific) Parses a column containing an XML string into a `StructType` with the
+ * specified schema. Returns `null`, in the case of an unparseable string.
*
- * @param e a string column containing XML data.
- * @param schema the schema to use when parsing the XML string
+ * @param e
+ * a string column containing XML data.
+ * @param schema
+ * the schema to use when parsing the XML string
* @group xml_funcs
* @since 4.0.0
*/
@@ -7283,17 +7542,18 @@ object functions {
// scalastyle:off line.size.limit
/**
- * (Java-specific) Parses a column containing a XML string into a `StructType`
- * with the specified schema. Returns `null`, in the case of an unparseable string.
- *
- * @param e a string column containing XML data.
- * @param schema the schema to use when parsing the XML string
- * @param options options to control how the XML is parsed. accepts the same options and the
- * XML data source.
- * See
- *
- * Data Source Option in the version you use.
+ * (Java-specific) Parses a column containing an XML string into a `StructType` with the
+ * specified schema. Returns `null`, in the case of an unparseable string.
+ *
+ * @param e
+ * a string column containing XML data.
+ * @param schema
+ * the schema to use when parsing the XML string
+ * @param options
+ * options to control how the XML is parsed. Accepts the same options as the XML data source.
+ * See Data
+ * Source Option in the version you use.
* @group xml_funcs
* @since 4.0.0
*/
@@ -7302,13 +7562,14 @@ object functions {
from_xml(e, schema, options.asScala.iterator)
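A minimal sketch of `from_xml` with an explicit `StructType` (per the docs above this is a 4.0.0 API; the XML literal is hypothetical):

{{{
  import org.apache.spark.sql.types._
  val df = spark.sql("SELECT '<row><a>1</a></row>' AS x")
  val schema = StructType(Seq(StructField("a", IntegerType)))
  df.select(from_xml(col("x"), schema)).show()   // {1}
}}}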
/**
- * Parses a column containing a XML string into the data type
- * corresponding to the specified schema.
- * Returns `null`, in the case of an unparseable string.
+ * Parses a column containing an XML string into the data type corresponding to the specified
+ * schema. Returns `null`, in the case of an unparseable string.
+ *
+ * @param e
+ * a string column containing XML data.
+ * @param schema
+ * the schema to use when parsing the XML string
*
- * @param e a string column containing XML data.
- * @param schema the schema to use when parsing the XML string
-
* @group xml_funcs
* @since 4.0.0
*/
@@ -7322,7 +7583,8 @@ object functions {
/**
* Parses a XML string and infers its schema in DDL format.
*
- * @param xml a XML string.
+ * @param xml
+ * an XML string.
* @group xml_funcs
* @since 4.0.0
*/
@@ -7331,7 +7593,8 @@ object functions {
/**
* Parses a XML string and infers its schema in DDL format.
*
- * @param xml a foldable string column containing a XML string.
+ * @param xml
+ * a foldable string column containing an XML string.
* @group xml_funcs
* @since 4.0.0
*/
@@ -7342,14 +7605,15 @@ object functions {
/**
* Parses a XML string and infers its schema in DDL format using options.
*
- * @param xml a foldable string column containing XML data.
- * @param options options to control how the xml is parsed. accepts the same options and the
- * XML data source.
- * See
- *
- * Data Source Option in the version you use.
- * @return a column with string literal containing schema in DDL format.
+ * @param xml
+ * a foldable string column containing XML data.
+ * @param options
+ * options to control how the XML is parsed. Accepts the same options as the XML data source.
+ * See Data
+ * Source Option in the version you use.
+ * @return
+ * a column with string literal containing schema in DDL format.
* @group xml_funcs
* @since 4.0.0
*/
@@ -7360,16 +7624,16 @@ object functions {
// scalastyle:off line.size.limit
/**
- * (Java-specific) Converts a column containing a `StructType` into a XML string with
- * the specified schema. Throws an exception, in the case of an unsupported type.
+ * (Java-specific) Converts a column containing a `StructType` into an XML string with the
+ * specified schema. Throws an exception, in the case of an unsupported type.
*
- * @param e a column containing a struct.
- * @param options options to control how the struct column is converted into a XML string.
- * It accepts the same options as the XML data source.
- * See
- *
- * Data Source Option in the version you use.
+ * @param e
+ * a column containing a struct.
+ * @param options
+ * options to control how the struct column is converted into an XML string. It accepts the
+ * same options as the XML data source. See Data
+ * Source Option in the version you use.
* @group xml_funcs
* @since 4.0.0
*/
@@ -7381,7 +7645,8 @@ object functions {
* Converts a column containing a `StructType` into a XML string with the specified schema.
* Throws an exception, in the case of an unsupported type.
*
- * @param e a column containing a struct.
+ * @param e
+ * a column containing a struct.
* @group xml_funcs
* @since 4.0.0
*/
@@ -7430,8 +7695,8 @@ object functions {
Column.fn("xpath_boolean", xml, path)
/**
- * Returns a double value, the value zero if no match is found,
- * or NaN if a match is found but the value is non-numeric.
+ * Returns a double value, the value zero if no match is found, or NaN if a match is found but
+ * the value is non-numeric.
*
* @group xml_funcs
* @since 3.5.0
@@ -7440,8 +7705,8 @@ object functions {
Column.fn("xpath_double", xml, path)
/**
- * Returns a double value, the value zero if no match is found,
- * or NaN if a match is found but the value is non-numeric.
+ * Returns a double value, the value zero if no match is found, or NaN if a match is found but
+ * the value is non-numeric.
*
* @group xml_funcs
* @since 3.5.0
@@ -7450,8 +7715,8 @@ object functions {
Column.fn("xpath_number", xml, path)
/**
- * Returns a float value, the value zero if no match is found,
- * or NaN if a match is found but the value is non-numeric.
+ * Returns a float value, the value zero if no match is found, or NaN if a match is found but
+ * the value is non-numeric.
*
* @group xml_funcs
* @since 3.5.0
@@ -7460,8 +7725,8 @@ object functions {
Column.fn("xpath_float", xml, path)
/**
- * Returns an integer value, or the value zero if no match is found,
- * or a match is found but the value is non-numeric.
+ * Returns an integer value, or the value zero if no match is found or if a match is found but
+ * the value is non-numeric.
*
* @group xml_funcs
* @since 3.5.0
@@ -7470,8 +7735,8 @@ object functions {
Column.fn("xpath_int", xml, path)
/**
- * Returns a long integer value, or the value zero if no match is found,
- * or a match is found but the value is non-numeric.
+ * Returns a long integer value, or the value zero if no match is found or if a match is found
+ * but the value is non-numeric.
*
* @group xml_funcs
* @since 3.5.0
@@ -7480,8 +7745,8 @@ object functions {
Column.fn("xpath_long", xml, path)
/**
- * Returns a short integer value, or the value zero if no match is found,
- * or a match is found but the value is non-numeric.
+ * Returns a short integer value, or the value zero if no match is found or if a match is found
+ * but the value is non-numeric.
*
* @group xml_funcs
* @since 3.5.0
@@ -7507,13 +7772,16 @@ object functions {
def hours(e: Column): Column = partitioning.hours(e)
/**
- * Converts the timestamp without time zone `sourceTs`
- * from the `sourceTz` time zone to `targetTz`.
+ * Converts the timestamp without time zone `sourceTs` from the `sourceTz` time zone to
+ * `targetTz`.
*
- * @param sourceTz the time zone for the input timestamp. If it is missed,
- * the current session time zone is used as the source time zone.
- * @param targetTz the time zone to which the input timestamp should be converted.
- * @param sourceTs a timestamp without time zone.
+ * @param sourceTz
+ * the time zone for the input timestamp. If it is missing, the current session time zone is
+ * used as the source time zone.
+ * @param targetTz
+ * the time zone to which the input timestamp should be converted.
+ * @param sourceTs
+ * a timestamp without time zone.
* @group datetime_funcs
* @since 3.5.0
*/
@@ -7521,11 +7789,12 @@ object functions {
Column.fn("convert_timezone", sourceTz, targetTz, sourceTs)
/**
- * Converts the timestamp without time zone `sourceTs`
- * from the current time zone to `targetTz`.
+ * Converts the timestamp without time zone `sourceTs` from the current time zone to `targetTz`.
*
- * @param targetTz the time zone to which the input timestamp should be converted.
- * @param sourceTs a timestamp without time zone.
+ * @param targetTz
+ * the time zone to which the input timestamp should be converted.
+ * @param sourceTs
+ * a timestamp without time zone.
* @group datetime_funcs
* @since 3.5.0
*/
@@ -7818,8 +8087,8 @@ object functions {
def isnotnull(col: Column): Column = Column.fn("isnotnull", col)
/**
- * Returns same result as the EQUAL(=) operator for non-null operands,
- * but returns true if both are null, false if one of the them is null.
+ * Returns the same result as the EQUAL(=) operator for non-null operands, but returns true if
+ * both are null, false if one of them is null.
*
* @group predicate_funcs
* @since 3.5.0
@@ -7908,15 +8177,15 @@ object functions {
|}""".stripMargin)
}
- */
+ */
//////////////////////////////////////////////////////////////////////////////////////////////
// Scala UDF functions
//////////////////////////////////////////////////////////////////////////////////////////////
/**
- * Obtains a `UserDefinedFunction` that wraps the given `Aggregator`
- * so that it may be used with untyped Data Frames.
+ * Obtains a `UserDefinedFunction` that wraps the given `Aggregator` so that it may be used with
+ * untyped Data Frames.
* {{{
* val agg = // Aggregator[IN, BUF, OUT]
*
@@ -7928,24 +8197,30 @@ object functions {
* spark.udf.register("myAggName", udaf(agg))
* }}}
*
- * @tparam IN the aggregator input type
- * @tparam BUF the aggregating buffer type
- * @tparam OUT the finalized output type
+ * @tparam IN
+ * the aggregator input type
+ * @tparam BUF
+ * the aggregating buffer type
+ * @tparam OUT
+ * the finalized output type
*
- * @param agg the typed Aggregator
+ * @param agg
+ * the typed Aggregator
*
- * @return a UserDefinedFunction that can be used as an aggregating expression.
+ * @return
+ * a UserDefinedFunction that can be used as an aggregating expression.
*
* @group udf_funcs
- * @note The input encoder is inferred from the input type IN.
+ * @note
+ * The input encoder is inferred from the input type IN.
*/
def udaf[IN: TypeTag, BUF, OUT](agg: Aggregator[IN, BUF, OUT]): UserDefinedFunction = {
udaf(agg, ScalaReflection.encoderFor[IN])
}
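A minimal sketch of wrapping an `Aggregator` with `udaf` and registering it (the aggregator name and logic are hypothetical):

{{{
  import org.apache.spark.sql.expressions.Aggregator
  import org.apache.spark.sql.{Encoder, Encoders}

  val sumAgg = new Aggregator[Long, Long, Long] {
    def zero: Long = 0L
    def reduce(buf: Long, in: Long): Long = buf + in
    def merge(b1: Long, b2: Long): Long = b1 + b2
    def finish(buf: Long): Long = buf
    def bufferEncoder: Encoder[Long] = Encoders.scalaLong
    def outputEncoder: Encoder[Long] = Encoders.scalaLong
  }
  spark.udf.register("my_sum", udaf(sumAgg))
  spark.range(5).selectExpr("my_sum(id)").show()   // 0 + 1 + 2 + 3 + 4 = 10
}}}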
/**
- * Obtains a `UserDefinedFunction` that wraps the given `Aggregator`
- * so that it may be used with untyped Data Frames.
+ * Obtains a `UserDefinedFunction` that wraps the given `Aggregator` so that it may be used with
+ * untyped Data Frames.
* {{{
* Aggregator agg = // custom Aggregator
* Encoder enc = // input encoder
@@ -7958,18 +8233,24 @@ object functions {
* spark.udf.register("myAggName", udaf(agg, enc))
* }}}
*
- * @tparam IN the aggregator input type
- * @tparam BUF the aggregating buffer type
- * @tparam OUT the finalized output type
+ * @tparam IN
+ * the aggregator input type
+ * @tparam BUF
+ * the aggregating buffer type
+ * @tparam OUT
+ * the finalized output type
*
- * @param agg the typed Aggregator
- * @param inputEncoder a specific input encoder to use
+ * @param agg
+ * the typed Aggregator
+ * @param inputEncoder
+ * a specific input encoder to use
*
- * @return a UserDefinedFunction that can be used as an aggregating expression
+ * @return
+ * a UserDefinedFunction that can be used as an aggregating expression
*
* @group udf_funcs
- * @note This overloading takes an explicit input encoder, to support UDAF
- * declarations in Java.
+ * @note
+ * This overloading takes an explicit input encoder, to support UDAF declarations in Java.
*/
def udaf[IN, BUF, OUT](
agg: Aggregator[IN, BUF, OUT],
@@ -7978,10 +8259,10 @@ object functions {
}
/**
- * Defines a Scala closure of 0 arguments as user-defined function (UDF).
- * The data types are automatically inferred based on the Scala closure's
- * signature. By default the returned UDF is deterministic. To change it to
- * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
+ * Defines a Scala closure of 0 arguments as user-defined function (UDF). The data types are
+ * automatically inferred based on the Scala closure's signature. By default the returned UDF is
+ * deterministic. To change it to nondeterministic, call the API
+ * `UserDefinedFunction.asNondeterministic()`.
*
* @group udf_funcs
* @since 1.3.0
@@ -7991,10 +8272,10 @@ object functions {
}
/**
- * Defines a Scala closure of 1 arguments as user-defined function (UDF).
- * The data types are automatically inferred based on the Scala closure's
- * signature. By default the returned UDF is deterministic. To change it to
- * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
+ * Defines a Scala closure of 1 argument as user-defined function (UDF). The data types are
+ * automatically inferred based on the Scala closure's signature. By default the returned UDF is
+ * deterministic. To change it to nondeterministic, call the API
+ * `UserDefinedFunction.asNondeterministic()`.
*
* @group udf_funcs
* @since 1.3.0
@@ -8004,120 +8285,242 @@ object functions {
}
/**
- * Defines a Scala closure of 2 arguments as user-defined function (UDF).
- * The data types are automatically inferred based on the Scala closure's
- * signature. By default the returned UDF is deterministic. To change it to
- * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
+ * Defines a Scala closure of 2 arguments as user-defined function (UDF). The data types are
+ * automatically inferred based on the Scala closure's signature. By default the returned UDF is
+ * deterministic. To change it to nondeterministic, call the API
+ * `UserDefinedFunction.asNondeterministic()`.
*
* @group udf_funcs
* @since 1.3.0
*/
- def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = {
- SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]])
+ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](
+ f: Function2[A1, A2, RT]): UserDefinedFunction = {
+ SparkUserDefinedFunction(
+ f,
+ implicitly[TypeTag[RT]],
+ implicitly[TypeTag[A1]],
+ implicitly[TypeTag[A2]])
}
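A minimal sketch of the two-argument `udf` overload (the closure is hypothetical):

{{{
  val add = udf((a: Int, b: Int) => a + b)
  spark.range(1).select(add(lit(1), lit(2)).as("sum")).show()   // 3
}}}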
/**
- * Defines a Scala closure of 3 arguments as user-defined function (UDF).
- * The data types are automatically inferred based on the Scala closure's
- * signature. By default the returned UDF is deterministic. To change it to
- * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
+ * Defines a Scala closure of 3 arguments as user-defined function (UDF). The data types are
+ * automatically inferred based on the Scala closure's signature. By default the returned UDF is
+ * deterministic. To change it to nondeterministic, call the API
+ * `UserDefinedFunction.asNondeterministic()`.
*
* @group udf_funcs
* @since 1.3.0
*/
- def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = {
- SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]])
+ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](
+ f: Function3[A1, A2, A3, RT]): UserDefinedFunction = {
+ SparkUserDefinedFunction(
+ f,
+ implicitly[TypeTag[RT]],
+ implicitly[TypeTag[A1]],
+ implicitly[TypeTag[A2]],
+ implicitly[TypeTag[A3]])
}
/**
- * Defines a Scala closure of 4 arguments as user-defined function (UDF).
- * The data types are automatically inferred based on the Scala closure's
- * signature. By default the returned UDF is deterministic. To change it to
- * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
+ * Defines a Scala closure of 4 arguments as user-defined function (UDF). The data types are
+ * automatically inferred based on the Scala closure's signature. By default the returned UDF is
+ * deterministic. To change it to nondeterministic, call the API
+ * `UserDefinedFunction.asNondeterministic()`.
*
* @group udf_funcs
* @since 1.3.0
*/
- def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = {
- SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]])
+ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](
+ f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = {
+ SparkUserDefinedFunction(
+ f,
+ implicitly[TypeTag[RT]],
+ implicitly[TypeTag[A1]],
+ implicitly[TypeTag[A2]],
+ implicitly[TypeTag[A3]],
+ implicitly[TypeTag[A4]])
}
/**
- * Defines a Scala closure of 5 arguments as user-defined function (UDF).
- * The data types are automatically inferred based on the Scala closure's
- * signature. By default the returned UDF is deterministic. To change it to
- * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
+ * Defines a Scala closure of 5 arguments as user-defined function (UDF). The data types are
+ * automatically inferred based on the Scala closure's signature. By default the returned UDF is
+ * deterministic. To change it to nondeterministic, call the API
+ * `UserDefinedFunction.asNondeterministic()`.
*
* @group udf_funcs
* @since 1.3.0
*/
- def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = {
- SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]])
+ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](
+ f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = {
+ SparkUserDefinedFunction(
+ f,
+ implicitly[TypeTag[RT]],
+ implicitly[TypeTag[A1]],
+ implicitly[TypeTag[A2]],
+ implicitly[TypeTag[A3]],
+ implicitly[TypeTag[A4]],
+ implicitly[TypeTag[A5]])
}
/**
- * Defines a Scala closure of 6 arguments as user-defined function (UDF).
- * The data types are automatically inferred based on the Scala closure's
- * signature. By default the returned UDF is deterministic. To change it to
- * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
+ * Defines a Scala closure of 6 arguments as user-defined function (UDF). The data types are
+ * automatically inferred based on the Scala closure's signature. By default the returned UDF is
+ * deterministic. To change it to nondeterministic, call the API
+ * `UserDefinedFunction.asNondeterministic()`.
*
* @group udf_funcs
* @since 1.3.0
*/
- def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = {
- SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]])
+ def udf[
+ RT: TypeTag,
+ A1: TypeTag,
+ A2: TypeTag,
+ A3: TypeTag,
+ A4: TypeTag,
+ A5: TypeTag,
+ A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = {
+ SparkUserDefinedFunction(
+ f,
+ implicitly[TypeTag[RT]],
+ implicitly[TypeTag[A1]],
+ implicitly[TypeTag[A2]],
+ implicitly[TypeTag[A3]],
+ implicitly[TypeTag[A4]],
+ implicitly[TypeTag[A5]],
+ implicitly[TypeTag[A6]])
}
/**
- * Defines a Scala closure of 7 arguments as user-defined function (UDF).
- * The data types are automatically inferred based on the Scala closure's
- * signature. By default the returned UDF is deterministic. To change it to
- * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
+ * Defines a Scala closure of 7 arguments as user-defined function (UDF). The data types are
+ * automatically inferred based on the Scala closure's signature. By default the returned UDF is
+ * deterministic. To change it to nondeterministic, call the API
+ * `UserDefinedFunction.asNondeterministic()`.
*
* @group udf_funcs
* @since 1.3.0
*/
- def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = {
- SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]])
+ def udf[
+ RT: TypeTag,
+ A1: TypeTag,
+ A2: TypeTag,
+ A3: TypeTag,
+ A4: TypeTag,
+ A5: TypeTag,
+ A6: TypeTag,
+ A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = {
+ SparkUserDefinedFunction(
+ f,
+ implicitly[TypeTag[RT]],
+ implicitly[TypeTag[A1]],
+ implicitly[TypeTag[A2]],
+ implicitly[TypeTag[A3]],
+ implicitly[TypeTag[A4]],
+ implicitly[TypeTag[A5]],
+ implicitly[TypeTag[A6]],
+ implicitly[TypeTag[A7]])
}
/**
- * Defines a Scala closure of 8 arguments as user-defined function (UDF).
- * The data types are automatically inferred based on the Scala closure's
- * signature. By default the returned UDF is deterministic. To change it to
- * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
+ * Defines a Scala closure of 8 arguments as user-defined function (UDF). The data types are
+ * automatically inferred based on the Scala closure's signature. By default the returned UDF is
+ * deterministic. To change it to nondeterministic, call the API
+ * `UserDefinedFunction.asNondeterministic()`.
*
* @group udf_funcs
* @since 1.3.0
*/
- def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = {
- SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]])
+ def udf[
+ RT: TypeTag,
+ A1: TypeTag,
+ A2: TypeTag,
+ A3: TypeTag,
+ A4: TypeTag,
+ A5: TypeTag,
+ A6: TypeTag,
+ A7: TypeTag,
+ A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = {
+ SparkUserDefinedFunction(
+ f,
+ implicitly[TypeTag[RT]],
+ implicitly[TypeTag[A1]],
+ implicitly[TypeTag[A2]],
+ implicitly[TypeTag[A3]],
+ implicitly[TypeTag[A4]],
+ implicitly[TypeTag[A5]],
+ implicitly[TypeTag[A6]],
+ implicitly[TypeTag[A7]],
+ implicitly[TypeTag[A8]])
}
/**
- * Defines a Scala closure of 9 arguments as user-defined function (UDF).
- * The data types are automatically inferred based on the Scala closure's
- * signature. By default the returned UDF is deterministic. To change it to
- * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
+ * Defines a Scala closure of 9 arguments as user-defined function (UDF). The data types are
+ * automatically inferred based on the Scala closure's signature. By default the returned UDF is
+ * deterministic. To change it to nondeterministic, call the API
+ * `UserDefinedFunction.asNondeterministic()`.
*
* @group udf_funcs
* @since 1.3.0
*/
- def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = {
- SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]])
+ def udf[
+ RT: TypeTag,
+ A1: TypeTag,
+ A2: TypeTag,
+ A3: TypeTag,
+ A4: TypeTag,
+ A5: TypeTag,
+ A6: TypeTag,
+ A7: TypeTag,
+ A8: TypeTag,
+ A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = {
+ SparkUserDefinedFunction(
+ f,
+ implicitly[TypeTag[RT]],
+ implicitly[TypeTag[A1]],
+ implicitly[TypeTag[A2]],
+ implicitly[TypeTag[A3]],
+ implicitly[TypeTag[A4]],
+ implicitly[TypeTag[A5]],
+ implicitly[TypeTag[A6]],
+ implicitly[TypeTag[A7]],
+ implicitly[TypeTag[A8]],
+ implicitly[TypeTag[A9]])
}
/**
- * Defines a Scala closure of 10 arguments as user-defined function (UDF).
- * The data types are automatically inferred based on the Scala closure's
- * signature. By default the returned UDF is deterministic. To change it to
- * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
+ * Defines a Scala closure of 10 arguments as user-defined function (UDF). The data types are
+ * automatically inferred based on the Scala closure's signature. By default the returned UDF is
+ * deterministic. To change it to nondeterministic, call the API
+ * `UserDefinedFunction.asNondeterministic()`.
*
* @group udf_funcs
* @since 1.3.0
*/
- def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = {
- SparkUserDefinedFunction(f, implicitly[TypeTag[RT]], implicitly[TypeTag[A1]], implicitly[TypeTag[A2]], implicitly[TypeTag[A3]], implicitly[TypeTag[A4]], implicitly[TypeTag[A5]], implicitly[TypeTag[A6]], implicitly[TypeTag[A7]], implicitly[TypeTag[A8]], implicitly[TypeTag[A9]], implicitly[TypeTag[A10]])
+ def udf[
+ RT: TypeTag,
+ A1: TypeTag,
+ A2: TypeTag,
+ A3: TypeTag,
+ A4: TypeTag,
+ A5: TypeTag,
+ A6: TypeTag,
+ A7: TypeTag,
+ A8: TypeTag,
+ A9: TypeTag,
+ A10: TypeTag](
+ f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = {
+ SparkUserDefinedFunction(
+ f,
+ implicitly[TypeTag[RT]],
+ implicitly[TypeTag[A1]],
+ implicitly[TypeTag[A2]],
+ implicitly[TypeTag[A3]],
+ implicitly[TypeTag[A4]],
+ implicitly[TypeTag[A5]],
+ implicitly[TypeTag[A6]],
+ implicitly[TypeTag[A7]],
+ implicitly[TypeTag[A8]],
+ implicitly[TypeTag[A9]],
+ implicitly[TypeTag[A10]])
}
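For readers skimming the patch, a minimal sketch (not part of the change) of how these typed `udf` overloads are used from the public API; the `SparkSession` setup, column names, and sample data below are assumptions for illustration only.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}

object TypedUdfSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("typed-udf-sketch").getOrCreate()
    import spark.implicits._

    // Two-argument closure: RT, A1 and A2 are inferred from the closure's signature.
    val fullName = udf((first: String, last: String) => s"$first $last")

    // Mark a closure as nondeterministic when it is not a pure function.
    val noisyAdd = udf((x: Int) => x + scala.util.Random.nextInt(10)).asNondeterministic()

    val df = Seq(("Ada", "Lovelace", 1), ("Alan", "Turing", 2)).toDF("first", "last", "id")
    df.select(fullName(col("first"), col("last")), noisyAdd(col("id"))).show()

    spark.stop()
  }
}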
//////////////////////////////////////////////////////////////////////////////////////////////
@@ -8125,10 +8528,10 @@ object functions {
//////////////////////////////////////////////////////////////////////////////////////////////
/**
- * Defines a Java UDF0 instance as user-defined function (UDF).
- * The caller must specify the output data type, and there is no automatic input type coercion.
- * By default the returned UDF is deterministic. To change it to nondeterministic, call the
- * API `UserDefinedFunction.asNondeterministic()`.
+ * Defines a Java UDF0 instance as user-defined function (UDF). The caller must specify the
+ * output data type, and there is no automatic input type coercion. By default the returned UDF
+ * is deterministic. To change it to nondeterministic, call the API
+ * `UserDefinedFunction.asNondeterministic()`.
*
* @group udf_funcs
* @since 2.3.0
@@ -8138,10 +8541,10 @@ object functions {
}
/**
- * Defines a Java UDF1 instance as user-defined function (UDF).
- * The caller must specify the output data type, and there is no automatic input type coercion.
- * By default the returned UDF is deterministic. To change it to nondeterministic, call the
- * API `UserDefinedFunction.asNondeterministic()`.
+ * Defines a Java UDF1 instance as user-defined function (UDF). The caller must specify the
+ * output data type, and there is no automatic input type coercion. By default the returned UDF
+ * is deterministic. To change it to nondeterministic, call the API
+ * `UserDefinedFunction.asNondeterministic()`.
*
* @group udf_funcs
* @since 2.3.0
@@ -8151,10 +8554,10 @@ object functions {
}
/**
- * Defines a Java UDF2 instance as user-defined function (UDF).
- * The caller must specify the output data type, and there is no automatic input type coercion.
- * By default the returned UDF is deterministic. To change it to nondeterministic, call the
- * API `UserDefinedFunction.asNondeterministic()`.
+ * Defines a Java UDF2 instance as user-defined function (UDF). The caller must specify the
+ * output data type, and there is no automatic input type coercion. By default the returned UDF
+ * is deterministic. To change it to nondeterministic, call the API
+ * `UserDefinedFunction.asNondeterministic()`.
*
* @group udf_funcs
* @since 2.3.0
@@ -8164,10 +8567,10 @@ object functions {
}
/**
- * Defines a Java UDF3 instance as user-defined function (UDF).
- * The caller must specify the output data type, and there is no automatic input type coercion.
- * By default the returned UDF is deterministic. To change it to nondeterministic, call the
- * API `UserDefinedFunction.asNondeterministic()`.
+ * Defines a Java UDF3 instance as user-defined function (UDF). The caller must specify the
+ * output data type, and there is no automatic input type coercion. By default the returned UDF
+ * is deterministic. To change it to nondeterministic, call the API
+ * `UserDefinedFunction.asNondeterministic()`.
*
* @group udf_funcs
* @since 2.3.0
@@ -8177,10 +8580,10 @@ object functions {
}
/**
- * Defines a Java UDF4 instance as user-defined function (UDF).
- * The caller must specify the output data type, and there is no automatic input type coercion.
- * By default the returned UDF is deterministic. To change it to nondeterministic, call the
- * API `UserDefinedFunction.asNondeterministic()`.
+ * Defines a Java UDF4 instance as user-defined function (UDF). The caller must specify the
+ * output data type, and there is no automatic input type coercion. By default the returned UDF
+ * is deterministic. To change it to nondeterministic, call the API
+ * `UserDefinedFunction.asNondeterministic()`.
*
* @group udf_funcs
* @since 2.3.0
@@ -8190,10 +8593,10 @@ object functions {
}
/**
- * Defines a Java UDF5 instance as user-defined function (UDF).
- * The caller must specify the output data type, and there is no automatic input type coercion.
- * By default the returned UDF is deterministic. To change it to nondeterministic, call the
- * API `UserDefinedFunction.asNondeterministic()`.
+ * Defines a Java UDF5 instance as user-defined function (UDF). The caller must specify the
+ * output data type, and there is no automatic input type coercion. By default the returned UDF
+ * is deterministic. To change it to nondeterministic, call the API
+ * `UserDefinedFunction.asNondeterministic()`.
*
* @group udf_funcs
* @since 2.3.0
@@ -8203,10 +8606,10 @@ object functions {
}
/**
- * Defines a Java UDF6 instance as user-defined function (UDF).
- * The caller must specify the output data type, and there is no automatic input type coercion.
- * By default the returned UDF is deterministic. To change it to nondeterministic, call the
- * API `UserDefinedFunction.asNondeterministic()`.
+ * Defines a Java UDF6 instance as user-defined function (UDF). The caller must specify the
+ * output data type, and there is no automatic input type coercion. By default the returned UDF
+ * is deterministic. To change it to nondeterministic, call the API
+ * `UserDefinedFunction.asNondeterministic()`.
*
* @group udf_funcs
* @since 2.3.0
@@ -8216,10 +8619,10 @@ object functions {
}
/**
- * Defines a Java UDF7 instance as user-defined function (UDF).
- * The caller must specify the output data type, and there is no automatic input type coercion.
- * By default the returned UDF is deterministic. To change it to nondeterministic, call the
- * API `UserDefinedFunction.asNondeterministic()`.
+ * Defines a Java UDF7 instance as user-defined function (UDF). The caller must specify the
+ * output data type, and there is no automatic input type coercion. By default the returned UDF
+ * is deterministic. To change it to nondeterministic, call the API
+ * `UserDefinedFunction.asNondeterministic()`.
*
* @group udf_funcs
* @since 2.3.0
@@ -8229,10 +8632,10 @@ object functions {
}
/**
- * Defines a Java UDF8 instance as user-defined function (UDF).
- * The caller must specify the output data type, and there is no automatic input type coercion.
- * By default the returned UDF is deterministic. To change it to nondeterministic, call the
- * API `UserDefinedFunction.asNondeterministic()`.
+ * Defines a Java UDF8 instance as user-defined function (UDF). The caller must specify the
+ * output data type, and there is no automatic input type coercion. By default the returned UDF
+ * is deterministic. To change it to nondeterministic, call the API
+ * `UserDefinedFunction.asNondeterministic()`.
*
* @group udf_funcs
* @since 2.3.0
@@ -8242,10 +8645,10 @@ object functions {
}
/**
- * Defines a Java UDF9 instance as user-defined function (UDF).
- * The caller must specify the output data type, and there is no automatic input type coercion.
- * By default the returned UDF is deterministic. To change it to nondeterministic, call the
- * API `UserDefinedFunction.asNondeterministic()`.
+ * Defines a Java UDF9 instance as user-defined function (UDF). The caller must specify the
+ * output data type, and there is no automatic input type coercion. By default the returned UDF
+ * is deterministic. To change it to nondeterministic, call the API
+ * `UserDefinedFunction.asNondeterministic()`.
*
* @group udf_funcs
* @since 2.3.0
@@ -8255,15 +8658,17 @@ object functions {
}
/**
- * Defines a Java UDF10 instance as user-defined function (UDF).
- * The caller must specify the output data type, and there is no automatic input type coercion.
- * By default the returned UDF is deterministic. To change it to nondeterministic, call the
- * API `UserDefinedFunction.asNondeterministic()`.
+ * Defines a Java UDF10 instance as user-defined function (UDF). The caller must specify the
+ * output data type, and there is no automatic input type coercion. By default the returned UDF
+ * is deterministic. To change it to nondeterministic, call the API
+ * `UserDefinedFunction.asNondeterministic()`.
*
* @group udf_funcs
* @since 2.3.0
*/
- def udf(f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = {
+ def udf(
+ f: UDF10[_, _, _, _, _, _, _, _, _, _, _],
+ returnType: DataType): UserDefinedFunction = {
SparkUserDefinedFunction(ToScalaUDF(f), returnType, 10)
}
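A hedged sketch (again, not part of the patch) of the Java-style overloads above: the `UDF2` instance supplies the logic, the output `DataType` must be given explicitly, and no input type coercion is applied. The DataFrame `df` and its columns mentioned in the comment are assumptions.

import org.apache.spark.sql.api.java.UDF2
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.DataTypes

object JavaUdfSketch {
  // A Java functional-interface UDF implemented from Scala.
  val joinWithDash: UDF2[String, String, String] = new UDF2[String, String, String] {
    override def call(a: String, b: String): String = a + "-" + b
  }

  // The return type is declared by the caller; inputs are passed through as-is.
  val joinUdf = udf(joinWithDash, DataTypes.StringType)

  // Usage, assuming a DataFrame `df` with string columns `a` and `b`:
  //   df.select(joinUdf(df("a"), df("b")))
}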
@@ -8273,8 +8678,8 @@ object functions {
/**
* Defines a deterministic user-defined function (UDF) using a Scala closure. For this variant,
* the caller must specify the output data type, and there is no automatic input type coercion.
- * By default the returned UDF is deterministic. To change it to nondeterministic, call the
- * API `UserDefinedFunction.asNondeterministic()`.
+ * By default the returned UDF is deterministic. To change it to nondeterministic, call the API
+ * `UserDefinedFunction.asNondeterministic()`.
*
* Note that, although the Scala closure can have primitive-type function argument, it doesn't
* work well with null values. Because the Scala closure is passed in as Any type, there is no
@@ -8283,14 +8688,18 @@ object functions {
* default value of the Java type for the null argument, e.g. `udf((x: Int) => x, IntegerType)`,
* the result is 0 for null input.
*
- * @param f A closure in Scala
- * @param dataType The output data type of the UDF
+ * @param f
+ * A closure in Scala
+ * @param dataType
+ * The output data type of the UDF
*
* @group udf_funcs
* @since 2.0.0
*/
- @deprecated("Scala `udf` method with return type parameter is deprecated. " +
- "Please use Scala `udf` method without return type parameter.", "3.0.0")
+ @deprecated(
+ "Scala `udf` method with return type parameter is deprecated. " +
+ "Please use Scala `udf` method without return type parameter.",
+ "3.0.0")
def udf(f: AnyRef, dataType: DataType): UserDefinedFunction = {
if (!SqlApiConf.get.legacyAllowUntypedScalaUDFs) {
throw CompilationErrors.usingUntypedScalaUDFError()
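To make the null-handling caveat above concrete, a small sketch (not part of the patch): with the deprecated untyped variant and a primitive `Int` argument, a null input silently becomes the Java default `0`, whereas the typed variant lets you model nullability explicitly. The untyped path is only reachable when the legacy flag checked above (`legacyAllowUntypedScalaUDFs`) is enabled.

import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.IntegerType

object UntypedUdfPitfallSketch {
  // Deprecated: the closure is passed as AnyRef and the return type is supplied by hand.
  // A null input bound to the primitive Int parameter yields 0 rather than null.
  val identityInt = udf((x: Int) => x, IntegerType)

  // Preferred: a typed UDF over a reference type, where null handling is explicit.
  val safeLength = udf((s: String) => Option(s).map(_.length))
}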
@@ -8309,8 +8718,7 @@ object functions {
def callUDF(udfName: String, cols: Column*): Column = call_function(udfName, cols: _*)
/**
- * Call an user-defined function.
- * Example:
+ * Call a user-defined function. Example:
* {{{
* import org.apache.spark.sql._
*
@@ -8329,9 +8737,10 @@ object functions {
/**
* Call a SQL function.
*
- * @param funcName function name that follows the SQL identifier syntax
- * (can be quoted, can be qualified)
- * @param cols the expression parameters of function
+ * @param funcName
+ * function name that follows the SQL identifier syntax (can be quoted, can be qualified)
+ * @param cols
+ *   the expression parameters of the function
* @group normal_funcs
* @since 3.5.0
*/
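A short, hedged usage sketch for `call_function`; the surrounding `spark`, the DataFrame `df`, its columns, and the registered UDF name are assumptions for illustration.

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{call_function, col}

object CallFunctionSketch {
  def demo(spark: SparkSession, df: DataFrame): Unit = {
    // Call a builtin by its SQL identifier.
    df.select(call_function("upper", col("name"))).show()

    // Call a UDF previously registered in the session, by name.
    spark.udf.register("plus_one", (i: Int) => i + 1)
    df.select(call_function("plus_one", col("id"))).show()
  }
}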
@@ -8352,7 +8761,7 @@ object functions {
// API in the same way. Once we land this fix, should deprecate
// functions.hours, days, months, years and bucket.
object partitioning {
- // scalastyle:on
+ // scalastyle:on
/**
* (Scala-specific) A transform for timestamps and dates to partition data into years.
*
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala
index 5762f9f6f5668..555a567053080 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala
@@ -25,8 +25,8 @@ import org.apache.spark.util.SparkClassUtils
/**
* Configuration for all objects that are placed in the `sql/api` project. The normal way of
- * accessing this class is through `SqlApiConf.get`. If this code is being used with sql/core
- * then its values are bound to the currently set SQLConf. With Spark Connect, it will default to
+ * accessing this class is through `SqlApiConf.get`. If this code is being used with sql/core then
+ * its values are bound to the currently set SQLConf. With Spark Connect, it will default to
* hardcoded values.
*/
private[sql] trait SqlApiConf {
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala
index b7b8e14afb387..13ef13e5894e0 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala
@@ -21,9 +21,9 @@ import java.util.concurrent.atomic.AtomicReference
/**
* SqlApiConfHelper is created to avoid a deadlock during a concurrent access to SQLConf and
* SqlApiConf, which is because SQLConf and SqlApiConf tries to load each other upon
- * initializations. SqlApiConfHelper is private to sql package and is not supposed to be
- * accessed by end users. Variables and methods within SqlApiConfHelper are defined to
- * be used by SQLConf and SqlApiConf only.
+ * initializations. SqlApiConfHelper is private to sql package and is not supposed to be accessed
+ * by end users. Variables and methods within SqlApiConfHelper are defined to be used by SQLConf
+ * and SqlApiConf only.
*/
private[sql] object SqlApiConfHelper {
// Shared keys.
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/ToScalaUDF.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/ToScalaUDF.scala
index 25ea37fadfa91..4d476108d9ec5 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/internal/ToScalaUDF.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/ToScalaUDF.scala
@@ -16,14 +16,56 @@
*/
package org.apache.spark.sql.internal
+import scala.jdk.CollectionConverters._
+
+import org.apache.spark.api.java.function.{CoGroupFunction, FilterFunction, FlatMapFunction, FlatMapGroupsFunction, FlatMapGroupsWithStateFunction, ForeachFunction, ForeachPartitionFunction, MapFunction, MapGroupsFunction, MapGroupsWithStateFunction, MapPartitionsFunction, ReduceFunction}
import org.apache.spark.sql.api.java._
+import org.apache.spark.sql.streaming.GroupState
/**
- * Helper class that provides conversions from org.apache.spark.sql.api.java.Function* to
- * scala.Function*.
+ * Helper class that provides conversions from org.apache.spark.sql.api.java.Function* and
+ * org.apache.spark.api.java.function.* to scala functions.
+ *
+ * Please note that this class is being used in Spark Connect Scala UDFs. We need to be careful
+ * with any modifications to this class, otherwise we will break backwards compatibility.
+ * Concretely this means you can only add methods to this class. You cannot rename the class, move
+ * it, change its `serialVersionUID`, remove methods, change method signatures, or change method
+ * semantics.
*/
@SerialVersionUID(2019907615267866045L)
private[sql] object ToScalaUDF extends Serializable {
+ def apply[T](f: FilterFunction[T]): T => Boolean = f.call
+
+ def apply[T](f: ReduceFunction[T]): (T, T) => T = f.call
+
+ def apply[V, W](f: MapFunction[V, W]): V => W = f.call
+
+ def apply[K, V, U](f: MapGroupsFunction[K, V, U]): (K, Iterator[V]) => U =
+ (key, values) => f.call(key, values.asJava)
+
+ def apply[K, V, S, U](
+ f: MapGroupsWithStateFunction[K, V, S, U]): (K, Iterator[V], GroupState[S]) => U =
+ (key, values, state) => f.call(key, values.asJava, state)
+
+ def apply[V, U](f: MapPartitionsFunction[V, U]): Iterator[V] => Iterator[U] =
+ values => f.call(values.asJava).asScala
+
+ def apply[K, V, U](f: FlatMapGroupsFunction[K, V, U]): (K, Iterator[V]) => Iterator[U] =
+ (key, values) => f.call(key, values.asJava).asScala
+
+ def apply[K, V, S, U](f: FlatMapGroupsWithStateFunction[K, V, S, U])
+ : (K, Iterator[V], GroupState[S]) => Iterator[U] =
+ (key, values, state) => f.call(key, values.asJava, state).asScala
+
+ def apply[K, V, U, R](
+ f: CoGroupFunction[K, V, U, R]): (K, Iterator[V], Iterator[U]) => Iterator[R] =
+ (key, left, right) => f.call(key, left.asJava, right.asJava).asScala
+
+ def apply[V](f: ForeachFunction[V]): V => Unit = f.call
+
+ def apply[V](f: ForeachPartitionFunction[V]): Iterator[V] => Unit =
+ values => f.call(values.asJava)
+
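An illustrative, standalone sketch of what the conversions above do (`ToScalaUDF` itself is `private[sql]`, so this is not user-facing code): each overload merely adapts a Java functional interface to the matching Scala function type.

import org.apache.spark.api.java.function.{FilterFunction, MapFunction}

object JavaToScalaFunctionSketch {
  val javaUpper: MapFunction[String, String] = new MapFunction[String, String] {
    override def call(v: String): String = v.toUpperCase
  }
  // Equivalent of ToScalaUDF.apply(f: MapFunction[V, W]): a plain V => W.
  val scalaUpper: String => String = javaUpper.call

  val javaNonEmpty: FilterFunction[String] = new FilterFunction[String] {
    override def call(v: String): Boolean = v.nonEmpty
  }
  // Equivalent of ToScalaUDF.apply(f: FilterFunction[T]): a plain T => Boolean.
  val scalaNonEmpty: String => Boolean = javaNonEmpty.call
}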
// scalastyle:off line.size.limit
/* register 0-22 were generated by this script
@@ -38,171 +80,757 @@ private[sql] object ToScalaUDF extends Serializable {
|/**
| * Create a scala.Function$i wrapper for a org.apache.spark.sql.api.java.UDF$i instance.
| */
- |def apply(f: UDF$i[$extTypeArgs]): AnyRef = {
+ |def apply(f: UDF$i[$extTypeArgs]): Function$i[$anyTypeArgs] = {
| $funcCall
|}""".stripMargin)
}
- */
+ */
/**
* Create a scala.Function0 wrapper for a org.apache.spark.sql.api.java.UDF0 instance.
*/
- def apply(f: UDF0[_]): AnyRef = {
- () => f.asInstanceOf[UDF0[Any]].call()
+ def apply(f: UDF0[_]): () => Any = { () =>
+ f.asInstanceOf[UDF0[Any]].call()
}
/**
* Create a scala.Function1 wrapper for a org.apache.spark.sql.api.java.UDF1 instance.
*/
- def apply(f: UDF1[_, _]): AnyRef = {
+ def apply(f: UDF1[_, _]): (Any) => Any = {
f.asInstanceOf[UDF1[Any, Any]].call(_: Any)
}
/**
* Create a scala.Function2 wrapper for a org.apache.spark.sql.api.java.UDF2 instance.
*/
- def apply(f: UDF2[_, _, _]): AnyRef = {
+ def apply(f: UDF2[_, _, _]): (Any, Any) => Any = {
f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any)
}
/**
* Create a scala.Function3 wrapper for a org.apache.spark.sql.api.java.UDF3 instance.
*/
- def apply(f: UDF3[_, _, _, _]): AnyRef = {
+ def apply(f: UDF3[_, _, _, _]): (Any, Any, Any) => Any = {
f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any)
}
/**
* Create a scala.Function4 wrapper for a org.apache.spark.sql.api.java.UDF4 instance.
*/
- def apply(f: UDF4[_, _, _, _, _]): AnyRef = {
+ def apply(f: UDF4[_, _, _, _, _]): (Any, Any, Any, Any) => Any = {
f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any)
}
/**
* Create a scala.Function5 wrapper for a org.apache.spark.sql.api.java.UDF5 instance.
*/
- def apply(f: UDF5[_, _, _, _, _, _]): AnyRef = {
- f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any)
+ def apply(f: UDF5[_, _, _, _, _, _]): (Any, Any, Any, Any, Any) => Any = {
+ f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]]
+ .call(_: Any, _: Any, _: Any, _: Any, _: Any)
}
/**
* Create a scala.Function6 wrapper for a org.apache.spark.sql.api.java.UDF6 instance.
*/
- def apply(f: UDF6[_, _, _, _, _, _, _]): AnyRef = {
- f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
+ def apply(f: UDF6[_, _, _, _, _, _, _]): (Any, Any, Any, Any, Any, Any) => Any = {
+ f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]]
+ .call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
}
/**
* Create a scala.Function7 wrapper for a org.apache.spark.sql.api.java.UDF7 instance.
*/
- def apply(f: UDF7[_, _, _, _, _, _, _, _]): AnyRef = {
- f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
+ def apply(f: UDF7[_, _, _, _, _, _, _, _]): (Any, Any, Any, Any, Any, Any, Any) => Any = {
+ f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]]
+ .call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
}
/**
* Create a scala.Function8 wrapper for a org.apache.spark.sql.api.java.UDF8 instance.
*/
- def apply(f: UDF8[_, _, _, _, _, _, _, _, _]): AnyRef = {
- f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
+ def apply(
+ f: UDF8[_, _, _, _, _, _, _, _, _]): (Any, Any, Any, Any, Any, Any, Any, Any) => Any = {
+ f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]]
+ .call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
}
/**
* Create a scala.Function9 wrapper for a org.apache.spark.sql.api.java.UDF9 instance.
*/
- def apply(f: UDF9[_, _, _, _, _, _, _, _, _, _]): AnyRef = {
- f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
+ def apply(f: UDF9[_, _, _, _, _, _, _, _, _, _])
+ : (Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any = {
+ f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]
+ .call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
}
/**
* Create a scala.Function10 wrapper for a org.apache.spark.sql.api.java.UDF10 instance.
*/
- def apply(f: UDF10[_, _, _, _, _, _, _, _, _, _, _]): AnyRef = {
- f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
+ def apply(f: UDF10[_, _, _, _, _, _, _, _, _, _, _])
+ : Function10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] = {
+ f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]
+ .call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
}
/**
* Create a scala.Function11 wrapper for a org.apache.spark.sql.api.java.UDF11 instance.
*/
- def apply(f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = {
- f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
+ def apply(f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _])
+ : Function11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] = {
+ f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]
+ .call(
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any)
}
/**
* Create a scala.Function12 wrapper for a org.apache.spark.sql.api.java.UDF12 instance.
*/
- def apply(f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = {
- f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
+ def apply(f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _])
+ : Function12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] = {
+ f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]
+ .call(
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any)
}
/**
* Create a scala.Function13 wrapper for a org.apache.spark.sql.api.java.UDF13 instance.
*/
- def apply(f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = {
- f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
+ def apply(f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _])
+ : Function13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] = {
+ f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]
+ .call(
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any)
}
/**
* Create a scala.Function14 wrapper for a org.apache.spark.sql.api.java.UDF14 instance.
*/
- def apply(f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = {
- f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
+ def apply(f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _])
+ : Function14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] = {
+ f.asInstanceOf[
+ UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]
+ .call(
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any)
}
/**
* Create a scala.Function15 wrapper for a org.apache.spark.sql.api.java.UDF15 instance.
*/
- def apply(f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = {
- f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
+ def apply(f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): Function15[
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any] = {
+ f.asInstanceOf[
+ UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]
+ .call(
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any)
}
/**
* Create a scala.Function16 wrapper for a org.apache.spark.sql.api.java.UDF16 instance.
*/
- def apply(f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = {
- f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
+ def apply(f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): Function16[
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any] = {
+ f.asInstanceOf[
+ UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]
+ .call(
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any)
}
/**
* Create a scala.Function17 wrapper for a org.apache.spark.sql.api.java.UDF17 instance.
*/
- def apply(f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = {
- f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
+ def apply(f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): Function17[
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any] = {
+ f.asInstanceOf[UDF17[
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any]]
+ .call(
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any)
}
/**
* Create a scala.Function18 wrapper for a org.apache.spark.sql.api.java.UDF18 instance.
*/
- def apply(f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = {
- f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
+ def apply(f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): Function18[
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any] = {
+ f.asInstanceOf[UDF18[
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any]]
+ .call(
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any)
}
/**
* Create a scala.Function19 wrapper for a org.apache.spark.sql.api.java.UDF19 instance.
*/
- def apply(f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = {
- f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
+ def apply(f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): Function19[
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any] = {
+ f.asInstanceOf[UDF19[
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any]]
+ .call(
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any)
}
/**
* Create a scala.Function20 wrapper for a org.apache.spark.sql.api.java.UDF20 instance.
*/
- def apply(f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = {
- f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
+ def apply(f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): Function20[
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any] = {
+ f.asInstanceOf[UDF20[
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any]]
+ .call(
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any)
}
/**
* Create a scala.Function21 wrapper for a org.apache.spark.sql.api.java.UDF21 instance.
*/
- def apply(f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = {
- f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
+ def apply(
+ f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): Function21[
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any] = {
+ f.asInstanceOf[UDF21[
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any]]
+ .call(
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any)
}
/**
* Create a scala.Function22 wrapper for a org.apache.spark.sql.api.java.UDF22 instance.
*/
- def apply(f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): AnyRef = {
- f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
+ def apply(
+ f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]): Function22[
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any] = {
+ f.asInstanceOf[UDF22[
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any,
+ Any]]
+ .call(
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any,
+ _: Any)
}
// scalastyle:on line.size.limit
}
+
+/**
+ * Adaptors from one UDF shape to another. For example adapting a foreach function for use in
+ * foreachPartition.
+ *
+ * Please note that this class is being used in Spark Connect Scala UDFs. We need to be careful
+ * with any modifications to this class, otherwise we will break backwards compatibility.
+ * Concretely this means you can only add methods to this class. You cannot rename the class, move
+ * it, change its `serialVersionUID`, remove methods, change method signatures, or change method
+ * semantics.
+ */
+@SerialVersionUID(0L) // TODO
+object UDFAdaptors extends Serializable {
+ def flatMapToMapPartitions[V, U](f: V => IterableOnce[U]): Iterator[V] => Iterator[U] =
+ values => values.flatMap(f)
+
+ def flatMapToMapPartitions[V, U](f: FlatMapFunction[V, U]): Iterator[V] => Iterator[U] =
+ values => values.flatMap(v => f.call(v).asScala)
+
+ def mapToMapPartitions[V, U](f: V => U): Iterator[V] => Iterator[U] = values => values.map(f)
+
+ def mapToMapPartitions[V, U](f: MapFunction[V, U]): Iterator[V] => Iterator[U] =
+ values => values.map(f.call)
+
+ def foreachToForeachPartition[T](f: T => Unit): Iterator[T] => Unit =
+ values => values.foreach(f)
+
+ def foreachToForeachPartition[T](f: ForeachFunction[T]): Iterator[T] => Unit =
+ values => values.foreach(f.call)
+
+ def foreachPartitionToMapPartitions[V, U](f: Iterator[V] => Unit): Iterator[V] => Iterator[U] =
+ values => {
+ f(values)
+ Iterator.empty[U]
+ }
+
+ def iterableOnceToSeq[A, B](f: A => IterableOnce[B]): A => Seq[B] =
+ value => f(value).iterator.toSeq
+
+ def mapGroupsToFlatMapGroups[K, V, U](
+ f: (K, Iterator[V]) => U): (K, Iterator[V]) => Iterator[U] =
+ (key, values) => Iterator.single(f(key, values))
+
+ def mapGroupsWithStateToFlatMapWithState[K, V, S, U](
+ f: (K, Iterator[V], GroupState[S]) => U): (K, Iterator[V], GroupState[S]) => Iterator[U] =
+ (key: K, values: Iterator[V], state: GroupState[S]) => Iterator(f(key, values, state))
+
+ def coGroupWithMappedValues[K, V, U, R, IV, IU](
+ f: (K, Iterator[V], Iterator[U]) => IterableOnce[R],
+ leftValueMapFunc: Option[IV => V],
+ rightValueMapFunc: Option[IU => U]): (K, Iterator[IV], Iterator[IU]) => IterableOnce[R] = {
+ (leftValueMapFunc, rightValueMapFunc) match {
+ case (None, None) =>
+ f.asInstanceOf[(K, Iterator[IV], Iterator[IU]) => IterableOnce[R]]
+ case (Some(mapLeft), None) =>
+ (k, left, right) => f(k, left.map(mapLeft), right.asInstanceOf[Iterator[U]])
+ case (None, Some(mapRight)) =>
+ (k, left, right) => f(k, left.asInstanceOf[Iterator[V]], right.map(mapRight))
+ case (Some(mapLeft), Some(mapRight)) =>
+ (k, left, right) => f(k, left.map(mapLeft), right.map(mapRight))
+ }
+ }
+
+ def flatMapGroupsWithMappedValues[K, IV, V, R](
+ f: (K, Iterator[V]) => IterableOnce[R],
+ valueMapFunc: Option[IV => V]): (K, Iterator[IV]) => IterableOnce[R] = valueMapFunc match {
+ case Some(mapValue) => (k, values) => f(k, values.map(mapValue))
+ case None => f.asInstanceOf[(K, Iterator[IV]) => IterableOnce[R]]
+ }
+
+ def flatMapGroupsWithStateWithMappedValues[K, IV, V, S, U](
+ f: (K, Iterator[V], GroupState[S]) => Iterator[U],
+ valueMapFunc: Option[IV => V]): (K, Iterator[IV], GroupState[S]) => Iterator[U] = {
+ valueMapFunc match {
+ case Some(mapValue) => (k, values, state) => f(k, values.map(mapValue), state)
+ case None => f.asInstanceOf[(K, Iterator[IV], GroupState[S]) => Iterator[U]]
+ }
+ }
+}
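A brief, standalone sketch of the adaptor idea implemented by `UDFAdaptors` above; the object lives in an internal package, so this only illustrates the shape of the transformations, using invented values.

object UdfAdaptorSketch {
  // Lift a per-element flatMap function to the Iterator-based shape used by mapPartitions.
  val explode: Int => Seq[Int] = x => Seq(x, x * 10)
  val asPartitions: Iterator[Int] => Iterator[Int] = values => values.flatMap(explode)

  // Lift a per-group aggregate into the flatMapGroups shape by wrapping its result in an Iterator.
  val sumGroup: (String, Iterator[Int]) => Int = (_, values) => values.sum
  val asFlatMapGroups: (String, Iterator[Int]) => Iterator[Int] =
    (key, values) => Iterator.single(sumGroup(key, values))

  def main(args: Array[String]): Unit = {
    assert(asPartitions(Iterator(1, 2)).toList == List(1, 10, 2, 20))
    assert(asFlatMapGroups("k", Iterator(1, 2, 3)).toList == List(6))
  }
}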
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala
index 9d77b2e6f3e22..51b26a1fa2435 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala
@@ -32,10 +32,12 @@ import org.apache.spark.util.SparkClassUtils
* implementation specific form (e.g. Catalyst expressions, or Connect protobuf messages).
*
* This API is a mirror image of Connect's expression.proto. There are a couple of extensions to
- * make constructing nodes easier (e.g. [[CaseWhenOtherwise]]). We could not use the actual connect
- * protobuf messages because of classpath clashes (e.g. Guava & gRPC) and Maven shading issues.
+ * make constructing nodes easier (e.g. [[CaseWhenOtherwise]]). We could not use the actual
+ * connect protobuf messages because of classpath clashes (e.g. Guava & gRPC) and Maven shading
+ * issues.
*/
private[sql] trait ColumnNode extends ColumnNodeLike {
+
/**
* Origin where the node was created.
*/
@@ -93,13 +95,17 @@ private[internal] object ColumnNode {
/**
* A literal column.
*
- * @param value of the literal. This is the unconverted input value.
- * @param dataType of the literal. If none is provided the dataType is inferred.
+ * @param value
+ * of the literal. This is the unconverted input value.
+ * @param dataType
+ * of the literal. If none is provided the dataType is inferred.
*/
private[sql] case class Literal(
value: Any,
dataType: Option[DataType] = None,
- override val origin: Origin = CurrentOrigin.get) extends ColumnNode with DataTypeErrorsBase {
+ override val origin: Origin = CurrentOrigin.get)
+ extends ColumnNode
+ with DataTypeErrorsBase {
override private[internal] def normalize(): Literal = copy(origin = NO_ORIGIN)
override def sql: String = value match {
@@ -116,16 +122,19 @@ private[sql] case class Literal(
/**
* Reference to an attribute produced by one of the underlying DataFrames.
*
- * @param unparsedIdentifier name of the attribute.
- * @param planId id of the plan (Dataframe) that produces the attribute.
- * @param isMetadataColumn whether this is a metadata column.
+ * @param unparsedIdentifier
+ * name of the attribute.
+ * @param planId
+ * id of the plan (Dataframe) that produces the attribute.
+ * @param isMetadataColumn
+ * whether this is a metadata column.
*/
private[sql] case class UnresolvedAttribute(
unparsedIdentifier: String,
planId: Option[Long] = None,
isMetadataColumn: Boolean = false,
override val origin: Origin = CurrentOrigin.get)
- extends ColumnNode {
+ extends ColumnNode {
override private[internal] def normalize(): UnresolvedAttribute =
copy(planId = None, origin = NO_ORIGIN)
override def sql: String = unparsedIdentifier
@@ -134,14 +143,16 @@ private[sql] case class UnresolvedAttribute(
/**
* Reference to all columns in a namespace (global, a Dataframe, or a nested struct).
*
- * @param unparsedTarget name of the namespace. None if the global namespace is supposed to be used.
- * @param planId id of the plan (Dataframe) that produces the attribute.
+ * @param unparsedTarget
+ * name of the namespace. None if the global namespace is supposed to be used.
+ * @param planId
+ * id of the plan (Dataframe) that produces the attribute.
*/
private[sql] case class UnresolvedStar(
unparsedTarget: Option[String],
planId: Option[Long] = None,
override val origin: Origin = CurrentOrigin.get)
- extends ColumnNode {
+ extends ColumnNode {
override private[internal] def normalize(): UnresolvedStar =
copy(planId = None, origin = NO_ORIGIN)
override def sql: String = unparsedTarget.map(_ + ".*").getOrElse("*")
@@ -151,10 +162,12 @@ private[sql] case class UnresolvedStar(
* Call a function. This can either be a built-in function, a UDF, or a UDF registered in the
* Catalog.
*
- * @param functionName of the function to invoke.
- * @param arguments to pass into the function.
- * @param isDistinct (aggregate only) whether the input of the aggregate function should be
- * de-duplicated.
+ * @param functionName
+ * of the function to invoke.
+ * @param arguments
+ * to pass into the function.
+ * @param isDistinct
+ * (aggregate only) whether the input of the aggregate function should be de-duplicated.
*/
private[sql] case class UnresolvedFunction(
functionName: String,
@@ -163,7 +176,7 @@ private[sql] case class UnresolvedFunction(
isUserDefinedFunction: Boolean = false,
isInternal: Boolean = false,
override val origin: Origin = CurrentOrigin.get)
- extends ColumnNode {
+ extends ColumnNode {
override private[internal] def normalize(): UnresolvedFunction =
copy(arguments = ColumnNode.normalize(arguments), origin = NO_ORIGIN)
@@ -173,11 +186,13 @@ private[sql] case class UnresolvedFunction(
/**
* Evaluate a SQL expression.
*
- * @param expression text to execute.
+ * @param expression
+ * text to execute.
*/
private[sql] case class SqlExpression(
expression: String,
- override val origin: Origin = CurrentOrigin.get) extends ColumnNode {
+ override val origin: Origin = CurrentOrigin.get)
+ extends ColumnNode {
override private[internal] def normalize(): SqlExpression = copy(origin = NO_ORIGIN)
override def sql: String = expression
}
@@ -185,15 +200,19 @@ private[sql] case class SqlExpression(
/**
* Name a column, and (optionally) modify its metadata.
*
- * @param child to name
- * @param name to use
- * @param metadata (optional) metadata to add.
+ * @param child
+ * to name
+ * @param name
+ * to use
+ * @param metadata
+ * (optional) metadata to add.
*/
private[sql] case class Alias(
child: ColumnNode,
name: Seq[String],
metadata: Option[Metadata] = None,
- override val origin: Origin = CurrentOrigin.get) extends ColumnNode {
+ override val origin: Origin = CurrentOrigin.get)
+ extends ColumnNode {
override private[internal] def normalize(): Alias =
copy(child = child.normalize(), origin = NO_ORIGIN)
@@ -210,15 +229,19 @@ private[sql] case class Alias(
* Cast the value of a Column to a different [[DataType]]. The behavior of the cast can be
* influenced by the `evalMode`.
*
- * @param child that produces the input value.
- * @param dataType to cast to.
- * @param evalMode (try/ansi/legacy) to use for the cast.
+ * @param child
+ * that produces the input value.
+ * @param dataType
+ * to cast to.
+ * @param evalMode
+ * (try/ansi/legacy) to use for the cast.
*/
private[sql] case class Cast(
child: ColumnNode,
dataType: DataType,
evalMode: Option[Cast.EvalMode] = None,
- override val origin: Origin = CurrentOrigin.get) extends ColumnNode {
+ override val origin: Origin = CurrentOrigin.get)
+ extends ColumnNode {
override private[internal] def normalize(): Cast =
copy(child = child.normalize(), origin = NO_ORIGIN)
@@ -237,13 +260,16 @@ private[sql] object Cast {
/**
* Reference to all columns in the global namespace in that match a regex.
*
- * @param regex name of the namespace. None if the global namespace is supposed to be used.
- * @param planId id of the plan (Dataframe) that produces the attribute.
+ * @param regex
+ * name of the namespace. None if the global namespace is supposed to be used.
+ * @param planId
+ * id of the plan (Dataframe) that produces the attribute.
*/
private[sql] case class UnresolvedRegex(
regex: String,
planId: Option[Long] = None,
- override val origin: Origin = CurrentOrigin.get) extends ColumnNode {
+ override val origin: Origin = CurrentOrigin.get)
+ extends ColumnNode {
override private[internal] def normalize(): UnresolvedRegex =
copy(planId = None, origin = NO_ORIGIN)
override def sql: String = regex
@@ -252,16 +278,19 @@ private[sql] case class UnresolvedRegex(
/**
* Sort the input column.
*
- * @param child to sort.
- * @param sortDirection to sort in, either Ascending or Descending.
- * @param nullOrdering where to place nulls, either at the begin or the end.
+ * @param child
+ * to sort.
+ * @param sortDirection
+ * to sort in, either Ascending or Descending.
+ * @param nullOrdering
+ *   where to place nulls, either at the beginning or the end.
*/
private[sql] case class SortOrder(
child: ColumnNode,
sortDirection: SortOrder.SortDirection,
nullOrdering: SortOrder.NullOrdering,
override val origin: Origin = CurrentOrigin.get)
- extends ColumnNode {
+ extends ColumnNode {
override private[internal] def normalize(): SortOrder =
copy(child = child.normalize(), origin = NO_ORIGIN)
@@ -280,14 +309,16 @@ private[sql] object SortOrder {
/**
* Evaluate a function within a window.
*
- * @param windowFunction function to execute.
- * @param windowSpec of the window.
+ * @param windowFunction
+ * function to execute.
+ * @param windowSpec
+ * of the window.
*/
private[sql] case class Window(
windowFunction: ColumnNode,
windowSpec: WindowSpec,
override val origin: Origin = CurrentOrigin.get)
- extends ColumnNode {
+ extends ColumnNode {
override private[internal] def normalize(): Window = copy(
windowFunction = windowFunction.normalize(),
windowSpec = windowSpec.normalize(),
@@ -299,7 +330,8 @@ private[sql] case class Window(
private[sql] case class WindowSpec(
partitionColumns: Seq[ColumnNode],
sortColumns: Seq[SortOrder],
- frame: Option[WindowFrame] = None) extends ColumnNodeLike {
+ frame: Option[WindowFrame] = None)
+ extends ColumnNodeLike {
override private[internal] def normalize(): WindowSpec = copy(
partitionColumns = ColumnNode.normalize(partitionColumns),
sortColumns = ColumnNode.normalize(sortColumns),
@@ -317,7 +349,7 @@ private[sql] case class WindowFrame(
frameType: WindowFrame.FrameType,
lower: WindowFrame.FrameBoundary,
upper: WindowFrame.FrameBoundary)
- extends ColumnNodeLike {
+ extends ColumnNodeLike {
override private[internal] def normalize(): WindowFrame =
copy(lower = lower.normalize(), upper = upper.normalize())
override private[internal] def sql: String =
@@ -352,13 +384,16 @@ private[sql] object WindowFrame {
/**
* Lambda function to execute. This typically passed as an argument to a function.
*
- * @param function to execute.
- * @param arguments the bound lambda variables.
+ * @param function
+ * to execute.
+ * @param arguments
+ * the bound lambda variables.
*/
private[sql] case class LambdaFunction(
function: ColumnNode,
arguments: Seq[UnresolvedNamedLambdaVariable],
- override val origin: Origin) extends ColumnNode {
+ override val origin: Origin)
+ extends ColumnNode {
override private[internal] def normalize(): LambdaFunction = copy(
function = function.normalize(),
@@ -382,11 +417,13 @@ object LambdaFunction {
/**
* Variable used in a [[LambdaFunction]].
*
- * @param name of the variable.
+ * @param name
+ * of the variable.
*/
private[sql] case class UnresolvedNamedLambdaVariable(
name: String,
- override val origin: Origin = CurrentOrigin.get) extends ColumnNode {
+ override val origin: Origin = CurrentOrigin.get)
+ extends ColumnNode {
override private[internal] def normalize(): UnresolvedNamedLambdaVariable =
copy(origin = NO_ORIGIN)
@@ -413,21 +450,22 @@ object UnresolvedNamedLambdaVariable {
}
/**
- * Extract a value from a complex type. This can be a field from a struct, a value from a map,
- * or an element from an array.
+ * Extract a value from a complex type. This can be a field from a struct, a value from a map, or
+ * an element from an array.
*
- * @param child that produces a complex value.
- * @param extraction that is used to access the complex type. This needs to be a string type for
- * structs and maps, and it needs to be an integer for arrays.
+ * @param child
+ * that produces a complex value.
+ * @param extraction
+ * that is used to access the complex type. This needs to be a string type for structs and maps,
+ * and it needs to be an integer for arrays.
*/
private[sql] case class UnresolvedExtractValue(
child: ColumnNode,
extraction: ColumnNode,
- override val origin: Origin = CurrentOrigin.get) extends ColumnNode {
- override private[internal] def normalize(): UnresolvedExtractValue = copy(
- child = child.normalize(),
- extraction = extraction.normalize(),
- origin = NO_ORIGIN)
+ override val origin: Origin = CurrentOrigin.get)
+ extends ColumnNode {
+ override private[internal] def normalize(): UnresolvedExtractValue =
+ copy(child = child.normalize(), extraction = extraction.normalize(), origin = NO_ORIGIN)
override def sql: String = s"${child.sql}[${extraction.sql}]"
}
@@ -435,15 +473,19 @@ private[sql] case class UnresolvedExtractValue(
/**
* Update or drop the field of a struct.
*
- * @param structExpression that will be updated.
- * @param fieldName name of the field to update.
- * @param valueExpression new value of the field. If this is None the field will be dropped.
+ * @param structExpression
+ * that will be updated.
+ * @param fieldName
+ * name of the field to update.
+ * @param valueExpression
+ * new value of the field. If this is None the field will be dropped.
*/
private[sql] case class UpdateFields(
structExpression: ColumnNode,
fieldName: String,
valueExpression: Option[ColumnNode] = None,
- override val origin: Origin = CurrentOrigin.get) extends ColumnNode {
+ override val origin: Origin = CurrentOrigin.get)
+ extends ColumnNode {
override private[internal] def normalize(): UpdateFields = copy(
structExpression = structExpression.normalize(),
valueExpression = ColumnNode.normalize(valueExpression),
@@ -455,18 +497,20 @@ private[sql] case class UpdateFields(
}
/**
- * Evaluate one or more conditional branches. The value of the first branch for which the predicate
- * evalutes to true is returned. If none of the branches evaluate to true, the value of `otherwise`
- * is returned.
+ * Evaluate one or more conditional branches. The value of the first branch for which the
+ * predicate evaluates to true is returned. If none of the branches evaluate to true, the value of
+ * `otherwise` is returned.
*
- * @param branches to evaluate. Each entry if a pair of condition and value.
- * @param otherwise (optional) to evaluate when none of the branches evaluate to true.
+ * @param branches
+ * to evaluate. Each entry is a pair of condition and value.
+ * @param otherwise
+ * (optional) to evaluate when none of the branches evaluate to true.
*/
private[sql] case class CaseWhenOtherwise(
branches: Seq[(ColumnNode, ColumnNode)],
otherwise: Option[ColumnNode] = None,
override val origin: Origin = CurrentOrigin.get)
- extends ColumnNode {
+ extends ColumnNode {
assert(branches.nonEmpty)
override private[internal] def normalize(): CaseWhenOtherwise = copy(
branches = branches.map(kv => (kv._1.normalize(), kv._2.normalize())),
@@ -483,15 +527,17 @@ private[sql] case class CaseWhenOtherwise(
/**
* Invoke an inline user defined function.
*
- * @param function to invoke.
- * @param arguments to pass into the user defined function.
+ * @param function
+ * to invoke.
+ * @param arguments
+ * to pass into the user defined function.
*/
private[sql] case class InvokeInlineUserDefinedFunction(
function: UserDefinedFunctionLike,
arguments: Seq[ColumnNode],
isDistinct: Boolean = false,
override val origin: Origin = CurrentOrigin.get)
- extends ColumnNode {
+ extends ColumnNode {
override private[internal] def normalize(): InvokeInlineUserDefinedFunction =
copy(arguments = ColumnNode.normalize(arguments), origin = NO_ORIGIN)
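The Window, WindowSpec, and WindowFrame nodes reformatted above are the internal column-node representation behind Spark's public window-function API. As a hedged illustration outside this patch, the sketch below builds the same pieces (window function, partition columns, sort order) through `org.apache.spark.sql.expressions.Window`; the local session and sample data are assumptions.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, row_number}

object WindowSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("window-sketch").getOrCreate()
    import spark.implicits._

    val sales = Seq(("a", 10), ("a", 20), ("b", 5)).toDF("key", "amount")

    // Partition columns + sort order: the same information the internal WindowSpec node carries.
    val spec = Window.partitionBy(col("key")).orderBy(col("amount").desc)

    // row_number() over the spec corresponds to a Window(windowFunction, windowSpec) node.
    sales.withColumn("rank", row_number().over(spec)).show()

    spark.stop()
  }
}
```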
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractArrayType.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractArrayType.scala
index 406449a337271..5c8c77985bb2c 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractArrayType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractArrayType.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.internal.types
import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType}
-
/**
* Use AbstractArrayType(AbstractDataType) for defining expected types for expression parameters.
*/
@@ -30,7 +29,7 @@ case class AbstractArrayType(elementType: AbstractDataType) extends AbstractData
override private[sql] def acceptsType(other: DataType): Boolean = {
other.isInstanceOf[ArrayType] &&
- elementType.acceptsType(other.asInstanceOf[ArrayType].elementType)
+ elementType.acceptsType(other.asInstanceOf[ArrayType].elementType)
}
override private[spark] def simpleString: String = s"array<${elementType.simpleString}>"
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractMapType.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractMapType.scala
index 62f422f6f80a7..32f4341839f01 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractMapType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractMapType.scala
@@ -19,23 +19,20 @@ package org.apache.spark.sql.internal.types
import org.apache.spark.sql.types.{AbstractDataType, DataType, MapType}
-
/**
- * Use AbstractMapType(AbstractDataType, AbstractDataType)
- * for defining expected types for expression parameters.
+ * Use AbstractMapType(AbstractDataType, AbstractDataType) for defining expected types for
+ * expression parameters.
*/
-case class AbstractMapType(
- keyType: AbstractDataType,
- valueType: AbstractDataType
- ) extends AbstractDataType {
+case class AbstractMapType(keyType: AbstractDataType, valueType: AbstractDataType)
+ extends AbstractDataType {
override private[sql] def defaultConcreteType: DataType =
MapType(keyType.defaultConcreteType, valueType.defaultConcreteType, valueContainsNull = true)
override private[sql] def acceptsType(other: DataType): Boolean = {
other.isInstanceOf[MapType] &&
- keyType.acceptsType(other.asInstanceOf[MapType].keyType) &&
- valueType.acceptsType(other.asInstanceOf[MapType].valueType)
+ keyType.acceptsType(other.asInstanceOf[MapType].keyType) &&
+ valueType.acceptsType(other.asInstanceOf[MapType].valueType)
}
override private[spark] def simpleString: String =
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala
index 05d1701eff74d..dc4ee013fd189 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala
@@ -51,3 +51,12 @@ case object StringTypeBinaryLcase extends AbstractStringType {
case object StringTypeAnyCollation extends AbstractStringType {
override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[StringType]
}
+
+/**
+ * Use StringTypeNonCSAICollation for expressions supporting all possible collation types except
+ * CS_AI collation types.
+ */
+case object StringTypeNonCSAICollation extends AbstractStringType {
+ override private[sql] def acceptsType(other: DataType): Boolean =
+ other.isInstanceOf[StringType] && other.asInstanceOf[StringType].isNonCSAI
+}
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ExpiredTimerInfo.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ExpiredTimerInfo.scala
index 49dc393f8481b..a0958aceb3b3a 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ExpiredTimerInfo.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ExpiredTimerInfo.scala
@@ -22,12 +22,13 @@ import java.io.Serializable
import org.apache.spark.annotation.{Evolving, Experimental}
/**
- * Class used to provide access to expired timer's expiry time. These values
- * are only relevant if the ExpiredTimerInfo is valid.
+ * Class used to provide access to expired timer's expiry time. These values are only relevant if
+ * the ExpiredTimerInfo is valid.
*/
@Experimental
@Evolving
private[sql] trait ExpiredTimerInfo extends Serializable {
+
/**
* Check if provided ExpiredTimerInfo is valid.
*/
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala
index f08a2fd3cc55c..146990917a3fc 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala
@@ -27,89 +27,89 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState
* `flatMapGroupsWithState` operations on `KeyValueGroupedDataset`.
*
* Detail description on `[map/flatMap]GroupsWithState` operation
- * --------------------------------------------------------------
- * Both, `mapGroupsWithState` and `flatMapGroupsWithState` in `KeyValueGroupedDataset`
- * will invoke the user-given function on each group (defined by the grouping function in
- * `Dataset.groupByKey()`) while maintaining a user-defined per-group state between invocations.
- * For a static batch Dataset, the function will be invoked once per group. For a streaming
- * Dataset, the function will be invoked for each group repeatedly in every trigger.
- * That is, in every batch of the `StreamingQuery`,
- * the function will be invoked once for each group that has data in the trigger. Furthermore,
- * if timeout is set, then the function will be invoked on timed-out groups (more detail below).
+ * -------------------------------------------------------------- Both, `mapGroupsWithState` and
+ * `flatMapGroupsWithState` in `KeyValueGroupedDataset` will invoke the user-given function on
+ * each group (defined by the grouping function in `Dataset.groupByKey()`) while maintaining a
+ * user-defined per-group state between invocations. For a static batch Dataset, the function will
+ * be invoked once per group. For a streaming Dataset, the function will be invoked for each group
+ * repeatedly in every trigger. That is, in every batch of the `StreamingQuery`, the function will
+ * be invoked once for each group that has data in the trigger. Furthermore, if timeout is set,
+ * then the function will be invoked on timed-out groups (more detail below).
*
* The function is invoked with the following parameters.
- * - The key of the group.
- * - An iterator containing all the values for this group.
- * - A user-defined state object set by previous invocations of the given function.
+ * - The key of the group.
+ * - An iterator containing all the values for this group.
+ * - A user-defined state object set by previous invocations of the given function.
*
* In case of a batch Dataset, there is only one invocation and the state object will be empty as
- * there is no prior state. Essentially, for batch Datasets, `[map/flatMap]GroupsWithState`
- * is equivalent to `[map/flatMap]Groups` and any updates to the state and/or timeouts have
- * no effect.
+ * there is no prior state. Essentially, for batch Datasets, `[map/flatMap]GroupsWithState` is
+ * equivalent to `[map/flatMap]Groups` and any updates to the state and/or timeouts have no
+ * effect.
*
* The major difference between `mapGroupsWithState` and `flatMapGroupsWithState` is that the
- * former allows the function to return one and only one record, whereas the latter
- * allows the function to return any number of records (including no records). Furthermore, the
+ * former allows the function to return one and only one record, whereas the latter allows the
+ * function to return any number of records (including no records). Furthermore, the
* `flatMapGroupsWithState` is associated with an operation output mode, which can be either
- * `Append` or `Update`. Semantically, this defines whether the output records of one trigger
- * is effectively replacing the previously output records (from previous triggers) or is appending
- * to the list of previously output records. Essentially, this defines how the Result Table (refer
- * to the semantics in the programming guide) is updated, and allows us to reason about the
- * semantics of later operations.
+ * `Append` or `Update`. Semantically, this defines whether the output records of one trigger is
+ * effectively replacing the previously output records (from previous triggers) or is appending to
+ * the list of previously output records. Essentially, this defines how the Result Table (refer to
+ * the semantics in the programming guide) is updated, and allows us to reason about the semantics
+ * of later operations.
*
- * Important points to note about the function (both mapGroupsWithState and flatMapGroupsWithState).
- * - In a trigger, the function will be called only the groups present in the batch. So do not
- * assume that the function will be called in every trigger for every group that has state.
- * - There is no guaranteed ordering of values in the iterator in the function, neither with
- * batch, nor with streaming Datasets.
- * - All the data will be shuffled before applying the function.
- * - If timeout is set, then the function will also be called with no values.
- * See more details on `GroupStateTimeout` below.
+ * Important points to note about the function (both mapGroupsWithState and
+ * flatMapGroupsWithState).
+ * - In a trigger, the function will be called only for the groups present in the batch. So do not
+ * assume that the function will be called in every trigger for every group that has state.
+ * - There is no guaranteed ordering of values in the iterator in the function, neither with
+ * batch, nor with streaming Datasets.
+ * - All the data will be shuffled before applying the function.
+ * - If timeout is set, then the function will also be called with no values. See more details
+ * on `GroupStateTimeout` below.
*
* Important points to note about using `GroupState`.
- * - The value of the state cannot be null. So updating state with null will throw
- * `IllegalArgumentException`.
- * - Operations on `GroupState` are not thread-safe. This is to avoid memory barriers.
- * - If `remove()` is called, then `exists()` will return `false`,
- * `get()` will throw `NoSuchElementException` and `getOption()` will return `None`
- * - After that, if `update(newState)` is called, then `exists()` will again return `true`,
- * `get()` and `getOption()`will return the updated value.
+ * - The value of the state cannot be null. So updating state with null will throw
+ * `IllegalArgumentException`.
+ * - Operations on `GroupState` are not thread-safe. This is to avoid memory barriers.
+ * - If `remove()` is called, then `exists()` will return `false`, `get()` will throw
+ * `NoSuchElementException` and `getOption()` will return `None`
+ * - After that, if `update(newState)` is called, then `exists()` will again return `true`,
+ * `get()` and `getOption()` will return the updated value.
*
* Important points to note about using `GroupStateTimeout`.
- * - The timeout type is a global param across all the groups (set as `timeout` param in
- * `[map|flatMap]GroupsWithState`, but the exact timeout duration/timestamp is configurable per
- * group by calling `setTimeout...()` in `GroupState`.
- * - Timeouts can be either based on processing time (i.e.
- * `GroupStateTimeout.ProcessingTimeTimeout`) or event time (i.e.
- * `GroupStateTimeout.EventTimeTimeout`).
- * - With `ProcessingTimeTimeout`, the timeout duration can be set by calling
- * `GroupState.setTimeoutDuration`. The timeout will occur when the clock has advanced by the set
- * duration. Guarantees provided by this timeout with a duration of D ms are as follows:
- * - Timeout will never occur before the clock time has advanced by D ms
- * - Timeout will occur eventually when there is a trigger in the query
- * (i.e. after D ms). So there is no strict upper bound on when the timeout would occur.
- * For example, the trigger interval of the query will affect when the timeout actually occurs.
- * If there is no data in the stream (for any group) for a while, then there will not be
- * any trigger and timeout function call will not occur until there is data.
- * - Since the processing time timeout is based on the clock time, it is affected by the
- * variations in the system clock (i.e. time zone changes, clock skew, etc.).
- * - With `EventTimeTimeout`, the user also has to specify the event time watermark in
- * the query using `Dataset.withWatermark()`. With this setting, data that is older than the
- * watermark is filtered out. The timeout can be set for a group by setting a timeout timestamp
- * using`GroupState.setTimeoutTimestamp()`, and the timeout would occur when the watermark
- * advances beyond the set timestamp. You can control the timeout delay by two parameters -
- * (i) watermark delay and an additional duration beyond the timestamp in the event (which
- * is guaranteed to be newer than watermark due to the filtering). Guarantees provided by this
- * timeout are as follows:
- * - Timeout will never occur before the watermark has exceeded the set timeout.
- * - Similar to processing time timeouts, there is no strict upper bound on the delay when
- * the timeout actually occurs. The watermark can advance only when there is data in the
- * stream and the event time of the data has actually advanced.
- * - When the timeout occurs for a group, the function is called for that group with no values, and
- * `GroupState.hasTimedOut()` set to true.
- * - The timeout is reset every time the function is called on a group, that is,
- * when the group has new data, or the group has timed out. So the user has to set the timeout
- * duration every time the function is called, otherwise, there will not be any timeout set.
+ * - The timeout type is a global param across all the groups (set as `timeout` param in
+ * `[map|flatMap]GroupsWithState`), but the exact timeout duration/timestamp is configurable
+ * per group by calling `setTimeout...()` in `GroupState`.
+ * - Timeouts can be either based on processing time (i.e.
+ * `GroupStateTimeout.ProcessingTimeTimeout`) or event time (i.e.
+ * `GroupStateTimeout.EventTimeTimeout`).
+ * - With `ProcessingTimeTimeout`, the timeout duration can be set by calling
+ * `GroupState.setTimeoutDuration`. The timeout will occur when the clock has advanced by the
+ * set duration. Guarantees provided by this timeout with a duration of D ms are as follows:
+ * - Timeout will never occur before the clock time has advanced by D ms
+ * - Timeout will occur eventually when there is a trigger in the query (i.e. after D ms). So
+ * there is no strict upper bound on when the timeout would occur. For example, the trigger
+ * interval of the query will affect when the timeout actually occurs. If there is no data
+ * in the stream (for any group) for a while, then there will not be any trigger and timeout
+ * function call will not occur until there is data.
+ * - Since the processing time timeout is based on the clock time, it is affected by the
+ * variations in the system clock (i.e. time zone changes, clock skew, etc.).
+ * - With `EventTimeTimeout`, the user also has to specify the event time watermark in the query
+ * using `Dataset.withWatermark()`. With this setting, data that is older than the watermark
+ * is filtered out. The timeout can be set for a group by setting a timeout timestamp
+ * using `GroupState.setTimeoutTimestamp()`, and the timeout would occur when the watermark
+ * advances beyond the set timestamp. You can control the timeout delay by two parameters -
+ * (i) the watermark delay and (ii) an additional duration beyond the timestamp in the event (which is
+ * guaranteed to be newer than watermark due to the filtering). Guarantees provided by this
+ * timeout are as follows:
+ * - Timeout will never occur before the watermark has exceeded the set timeout.
+ * - Similar to processing time timeouts, there is no strict upper bound on the delay when the
+ * timeout actually occurs. The watermark can advance only when there is data in the stream
+ * and the event time of the data has actually advanced.
+ * - When the timeout occurs for a group, the function is called for that group with no values,
+ * and `GroupState.hasTimedOut()` set to true.
+ * - The timeout is reset every time the function is called on a group, that is, when the group
+ * has new data, or the group has timed out. So the user has to set the timeout duration every
+ * time the function is called, otherwise, there will not be any timeout set.
*
* `[map/flatMap]GroupsWithState` can take a user defined initial state as an additional argument.
* This state will be applied when the first batch of the streaming query is processed. If there
@@ -181,7 +181,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState
* state.setTimeoutDuration("1 hour"); // Set the timeout
* }
* ...
-* // return something
+ * // return something
* }
* };
*
@@ -191,8 +191,9 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState
* mappingFunction, Encoders.INT, Encoders.STRING, GroupStateTimeout.ProcessingTimeTimeout);
* }}}
*
- * @tparam S User-defined type of the state to be stored for each group. Must be encodable into
- * Spark SQL types (see `Encoder` for more details).
+ * @tparam S
+ * User-defined type of the state to be stored for each group. Must be encodable into Spark SQL
+ * types (see `Encoder` for more details).
* @since 2.2.0
*/
@Experimental
@@ -217,44 +218,48 @@ trait GroupState[S] extends LogicalGroupState[S] {
/**
* Whether the function has been called because the key has timed out.
- * @note This can return true only when timeouts are enabled in `[map/flatMap]GroupsWithState`.
+ * @note
+ * This can return true only when timeouts are enabled in `[map/flatMap]GroupsWithState`.
*/
def hasTimedOut: Boolean
-
/**
* Set the timeout duration in ms for this key.
*
- * @note [[GroupStateTimeout Processing time timeout]] must be enabled in
- * `[map/flatMap]GroupsWithState` for calling this method.
- * @note This method has no effect when used in a batch query.
+ * @note
+ * [[GroupStateTimeout Processing time timeout]] must be enabled in
+ * `[map/flatMap]GroupsWithState` for calling this method.
+ * @note
+ * This method has no effect when used in a batch query.
*/
@throws[IllegalArgumentException]("if 'durationMs' is not positive")
@throws[UnsupportedOperationException](
"if processing time timeout has not been enabled in [map|flatMap]GroupsWithState")
def setTimeoutDuration(durationMs: Long): Unit
-
/**
* Set the timeout duration for this key as a string. For example, "1 hour", "2 days", etc.
*
- * @note [[GroupStateTimeout Processing time timeout]] must be enabled in
- * `[map/flatMap]GroupsWithState` for calling this method.
- * @note This method has no effect when used in a batch query.
+ * @note
+ * [[GroupStateTimeout Processing time timeout]] must be enabled in
+ * `[map/flatMap]GroupsWithState` for calling this method.
+ * @note
+ * This method has no effect when used in a batch query.
*/
@throws[IllegalArgumentException]("if 'duration' is not a valid duration")
@throws[UnsupportedOperationException](
"if processing time timeout has not been enabled in [map|flatMap]GroupsWithState")
def setTimeoutDuration(duration: String): Unit
-
/**
- * Set the timeout timestamp for this key as milliseconds in epoch time.
- * This timestamp cannot be older than the current watermark.
+ * Set the timeout timestamp for this key as milliseconds in epoch time. This timestamp cannot
+ * be older than the current watermark.
*
- * @note [[GroupStateTimeout Event time timeout]] must be enabled in
- * `[map/flatMap]GroupsWithState` for calling this method.
- * @note This method has no effect when used in a batch query.
+ * @note
+ * [[GroupStateTimeout Event time timeout]] must be enabled in `[map/flatMap]GroupsWithState`
+ * for calling this method.
+ * @note
+ * This method has no effect when used in a batch query.
*/
@throws[IllegalArgumentException](
"if 'timestampMs' is not positive or less than the current watermark in a streaming query")
@@ -262,16 +267,16 @@ trait GroupState[S] extends LogicalGroupState[S] {
"if event time timeout has not been enabled in [map|flatMap]GroupsWithState")
def setTimeoutTimestamp(timestampMs: Long): Unit
-
/**
* Set the timeout timestamp for this key as milliseconds in epoch time and an additional
- * duration as a string (e.g. "1 hour", "2 days", etc.).
- * The final timestamp (including the additional duration) cannot be older than the
- * current watermark.
+ * duration as a string (e.g. "1 hour", "2 days", etc.). The final timestamp (including the
+ * additional duration) cannot be older than the current watermark.
*
- * @note [[GroupStateTimeout Event time timeout]] must be enabled in
- * `[map/flatMap]GroupsWithState` for calling this method.
- * @note This method has no side effect when used in a batch query.
+ * @note
+ * [[GroupStateTimeout Event time timeout]] must be enabled in `[map/flatMap]GroupsWithState`
+ * for calling this method.
+ * @note
+ * This method has no side effect when used in a batch query.
*/
@throws[IllegalArgumentException](
"if 'additionalDuration' is invalid or the final timeout timestamp is less than " +
@@ -280,56 +285,57 @@ trait GroupState[S] extends LogicalGroupState[S] {
"if event time timeout has not been enabled in [map|flatMap]GroupsWithState")
def setTimeoutTimestamp(timestampMs: Long, additionalDuration: String): Unit
-
/**
- * Set the timeout timestamp for this key as a java.sql.Date.
- * This timestamp cannot be older than the current watermark.
+ * Set the timeout timestamp for this key as a java.sql.Date. This timestamp cannot be older
+ * than the current watermark.
*
- * @note [[GroupStateTimeout Event time timeout]] must be enabled in
- * `[map/flatMap]GroupsWithState` for calling this method.
- * @note This method has no side effect when used in a batch query.
+ * @note
+ * [[GroupStateTimeout Event time timeout]] must be enabled in `[map/flatMap]GroupsWithState`
+ * for calling this method.
+ * @note
+ * This method has no side effect when used in a batch query.
*/
@throws[UnsupportedOperationException](
"if event time timeout has not been enabled in [map|flatMap]GroupsWithState")
def setTimeoutTimestamp(timestamp: java.sql.Date): Unit
-
/**
- * Set the timeout timestamp for this key as a java.sql.Date and an additional
- * duration as a string (e.g. "1 hour", "2 days", etc.).
- * The final timestamp (including the additional duration) cannot be older than the
- * current watermark.
+ * Set the timeout timestamp for this key as a java.sql.Date and an additional duration as a
+ * string (e.g. "1 hour", "2 days", etc.). The final timestamp (including the additional
+ * duration) cannot be older than the current watermark.
*
- * @note [[GroupStateTimeout Event time timeout]] must be enabled in
- * `[map/flatMap]GroupsWithState` for calling this method.
- * @note This method has no side effect when used in a batch query.
+ * @note
+ * [[GroupStateTimeout Event time timeout]] must be enabled in `[map/flatMap]GroupsWithState`
+ * for calling this method.
+ * @note
+ * This method has no side effect when used in a batch query.
*/
@throws[IllegalArgumentException]("if 'additionalDuration' is invalid")
@throws[UnsupportedOperationException](
"if event time timeout has not been enabled in [map|flatMap]GroupsWithState")
def setTimeoutTimestamp(timestamp: java.sql.Date, additionalDuration: String): Unit
-
/**
* Get the current event time watermark as milliseconds in epoch time.
*
- * @note In a streaming query, this can be called only when watermark is set before calling
- * `[map/flatMap]GroupsWithState`. In a batch query, this method always returns -1.
- * @note The watermark gets propagated in the end of each query. As a result, this method will
- * return 0 (1970-01-01T00:00:00) for the first micro-batch. If you use this value
- * as a part of the timestamp set in the `setTimeoutTimestamp`, it may lead to the
- * state expiring immediately in the next micro-batch, once the watermark gets the
- * real value from your data.
+ * @note
+ * In a streaming query, this can be called only when watermark is set before calling
+ * `[map/flatMap]GroupsWithState`. In a batch query, this method always returns -1.
+ * @note
+ * The watermark gets propagated in the end of each query. As a result, this method will
+ * return 0 (1970-01-01T00:00:00) for the first micro-batch. If you use this value as a part
+ * of the timestamp set in the `setTimeoutTimestamp`, it may lead to the state expiring
+ * immediately in the next micro-batch, once the watermark gets the real value from your data.
*/
@throws[UnsupportedOperationException](
"if watermark has not been set before in [map|flatMap]GroupsWithState")
def getCurrentWatermarkMs(): Long
-
/**
* Get the current processing time as milliseconds in epoch time.
- * @note In a streaming query, this will return a constant value throughout the duration of a
- * trigger, even if the trigger is re-executed.
+ * @note
+ * In a streaming query, this will return a constant value throughout the duration of a
+ * trigger, even if the trigger is re-executed.
*/
def getCurrentProcessingTimeMs(): Long
}
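To make the timeout contract above concrete, here is a minimal sketch (illustrative only, not part of this patch) of a `flatMapGroupsWithState` function with a processing-time timeout; the string key/value types and the counting logic are assumptions.

```scala
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}

object GroupStateSketch {
  // Running count per key; emits the final count once the key times out.
  def countPerKey(
      key: String,
      values: Iterator[String],
      state: GroupState[Long]): Iterator[(String, Long)] = {
    if (state.hasTimedOut) {
      // Called with an empty iterator when the processing-time timeout fires.
      val result = Iterator.single((key, state.get))
      state.remove()
      result
    } else {
      state.update(state.getOption.getOrElse(0L) + values.size)
      // The timeout must be re-set on every invocation, as documented above.
      state.setTimeoutDuration("1 hour")
      Iterator.empty
    }
  }

  // Wiring, given some Dataset[(String, String)] named `events`:
  //   events.groupByKey(_._1)
  //     .flatMapGroupsWithState(OutputMode.Update, GroupStateTimeout.ProcessingTimeTimeout)(
  //       (k, vs, s: GroupState[Long]) => countPerKey(k, vs.map(_._2), s))
}
```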
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ListState.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ListState.scala
index 0e2d6cc3778c6..568578d1f4ff6 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ListState.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ListState.scala
@@ -21,8 +21,7 @@ import org.apache.spark.annotation.{Evolving, Experimental}
@Experimental
@Evolving
/**
- * Interface used for arbitrary stateful operations with the v2 API to capture
- * list value state.
+ * Interface used for arbitrary stateful operations with the v2 API to capture list value state.
*/
private[sql] trait ListState[S] extends Serializable {
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/MapState.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/MapState.scala
index 030c3ee989c6f..7b01888bbac49 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/MapState.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/MapState.scala
@@ -21,10 +21,10 @@ import org.apache.spark.annotation.{Evolving, Experimental}
@Experimental
@Evolving
/**
- * Interface used for arbitrary stateful operations with the v2 API to capture
- * map value state.
+ * Interface used for arbitrary stateful operations with the v2 API to capture map value state.
*/
trait MapState[K, V] extends Serializable {
+
/** Whether state exists or not. */
def exists(): Boolean
@@ -35,7 +35,7 @@ trait MapState[K, V] extends Serializable {
def containsKey(key: K): Boolean
/** Update value for given user key */
- def updateValue(key: K, value: V) : Unit
+ def updateValue(key: K, value: V): Unit
/** Get the map associated with grouping key */
def iterator(): Iterator[(K, V)]
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/QueryInfo.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/QueryInfo.scala
index 7754a514fdd6b..f239bcff49fea 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/QueryInfo.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/QueryInfo.scala
@@ -22,8 +22,8 @@ import java.util.UUID
import org.apache.spark.annotation.{Evolving, Experimental}
/**
- * Represents the query info provided to the stateful processor used in the
- * arbitrary state API v2 to easily identify task retries on the same partition.
+ * Represents the query info provided to the stateful processor used in the arbitrary state API v2
+ * to easily identify task retries on the same partition.
*/
@Experimental
@Evolving
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala
index 54e6a9a4ab678..d2c6010454c55 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala
@@ -31,31 +31,35 @@ import org.apache.spark.sql.errors.ExecutionErrors
private[sql] abstract class StatefulProcessor[K, I, O] extends Serializable {
/**
- * Handle to the stateful processor that provides access to the state store and other
- * stateful processing related APIs.
+ * Handle to the stateful processor that provides access to the state store and other stateful
+ * processing related APIs.
*/
private var statefulProcessorHandle: StatefulProcessorHandle = null
/**
- * Function that will be invoked as the first method that allows for users to
- * initialize all their state variables and perform other init actions before handling data.
- * @param outputMode - output mode for the stateful processor
- * @param timeMode - time mode for the stateful processor.
+ * Function that will be invoked as the first method that allows for users to initialize all
+ * their state variables and perform other init actions before handling data.
+ * @param outputMode
+ * \- output mode for the stateful processor
+ * @param timeMode
+ * \- time mode for the stateful processor.
*/
- def init(
- outputMode: OutputMode,
- timeMode: TimeMode): Unit
+ def init(outputMode: OutputMode, timeMode: TimeMode): Unit
/**
* Function that will allow users to interact with input data rows along with the grouping key
* and current timer values and optionally provide output rows.
- * @param key - grouping key
- * @param inputRows - iterator of input rows associated with grouping key
- * @param timerValues - instance of TimerValues that provides access to current processing/event
- * time if available
- * @param expiredTimerInfo - instance of ExpiredTimerInfo that provides access to expired timer
- * if applicable
- * @return - Zero or more output rows
+ * @param key
+ * \- grouping key
+ * @param inputRows
+ * \- iterator of input rows associated with grouping key
+ * @param timerValues
+ * \- instance of TimerValues that provides access to current processing/event time if
+ * available
+ * @param expiredTimerInfo
+ * \- instance of ExpiredTimerInfo that provides access to expired timer if applicable
+ * @return
+ * \- Zero or more output rows
*/
def handleInputRows(
key: K,
@@ -64,16 +68,17 @@ private[sql] abstract class StatefulProcessor[K, I, O] extends Serializable {
expiredTimerInfo: ExpiredTimerInfo): Iterator[O]
/**
- * Function called as the last method that allows for users to perform
- * any cleanup or teardown operations.
+ * Function called as the last method that allows for users to perform any cleanup or teardown
+ * operations.
*/
- def close (): Unit = {}
+ def close(): Unit = {}
/**
* Function to set the stateful processor handle that will be used to interact with the state
* store and other stateful processor related operations.
*
- * @param handle - instance of StatefulProcessorHandle
+ * @param handle
+ * \- instance of StatefulProcessorHandle
*/
final def setHandle(handle: StatefulProcessorHandle): Unit = {
statefulProcessorHandle = handle
@@ -82,7 +87,8 @@ private[sql] abstract class StatefulProcessor[K, I, O] extends Serializable {
/**
* Function to get the stateful processor handle that will be used to interact with the state
*
- * @return handle - instance of StatefulProcessorHandle
+ * @return
+ * handle - instance of StatefulProcessorHandle
*/
final def getHandle: StatefulProcessorHandle = {
if (statefulProcessorHandle == null) {
@@ -93,23 +99,25 @@ private[sql] abstract class StatefulProcessor[K, I, O] extends Serializable {
}
/**
- * Stateful processor with support for specifying initial state.
- * Accepts a user-defined type as initial state to be initialized in the first batch.
- * This can be used for starting a new streaming query with existing state from a
- * previous streaming query.
+ * Stateful processor with support for specifying initial state. Accepts a user-defined type as
+ * initial state to be initialized in the first batch. This can be used for starting a new
+ * streaming query with existing state from a previous streaming query.
*/
@Experimental
@Evolving
private[sql] abstract class StatefulProcessorWithInitialState[K, I, O, S]
- extends StatefulProcessor[K, I, O] {
+ extends StatefulProcessor[K, I, O] {
/**
* Function that will be invoked only in the first batch for users to process initial states.
*
- * @param key - grouping key
- * @param initialState - A row in the initial state to be processed
- * @param timerValues - instance of TimerValues that provides access to current processing/event
- * time if available
+ * @param key
+ * \- grouping key
+ * @param initialState
+ * \- A row in the initial state to be processed
+ * @param timerValues
+ * \- instance of TimerValues that provides access to current processing/event time if
+ * available
*/
def handleInitialState(key: K, initialState: S, timerValues: TimerValues): Unit
}
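As a rough sketch of the `StatefulProcessor` contract documented above (the API is still `private[sql]` and Experimental in this patch, so this is illustrative only): state variables are created through the handle inside `init()`, and `handleInputRows` runs once per grouping key with data. The `ValueState` method names (`exists`, `get`, `update`) are assumptions based on the arbitrary state v2 API.

```scala
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.streaming._

// Keeps a running count per key and emits the updated count for each input batch.
class CountProcessor extends StatefulProcessor[String, String, (String, Long)] {
  @transient private var countState: ValueState[Long] = _

  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
    // State variables must be created through the handle, inside init().
    countState = getHandle.getValueState[Long]("count", Encoders.scalaLong)
  }

  override def handleInputRows(
      key: String,
      inputRows: Iterator[String],
      timerValues: TimerValues,
      expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, Long)] = {
    val current = if (countState.exists()) countState.get() else 0L
    val updated = current + inputRows.size
    countState.update(updated)
    Iterator.single((key, updated))
  }
}
```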
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala
index 4dc2ca875ef0e..d1eca0f3967d9 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala
@@ -22,39 +22,46 @@ import org.apache.spark.annotation.{Evolving, Experimental}
import org.apache.spark.sql.Encoder
/**
- * Represents the operation handle provided to the stateful processor used in the
- * arbitrary state API v2.
+ * Represents the operation handle provided to the stateful processor used in the arbitrary state
+ * API v2.
*/
@Experimental
@Evolving
private[sql] trait StatefulProcessorHandle extends Serializable {
/**
- * Function to create new or return existing single value state variable of given type.
- * The user must ensure to call this function only within the `init()` method of the
- * StatefulProcessor.
+ * Function to create new or return existing single value state variable of given type. The user
+ * must ensure to call this function only within the `init()` method of the StatefulProcessor.
*
- * @param stateName - name of the state variable
- * @param valEncoder - SQL encoder for state variable
- * @tparam T - type of state variable
- * @return - instance of ValueState of type T that can be used to store state persistently
+ * @param stateName
+ * \- name of the state variable
+ * @param valEncoder
+ * \- SQL encoder for state variable
+ * @tparam T
+ * \- type of state variable
+ * @return
+ * \- instance of ValueState of type T that can be used to store state persistently
*/
def getValueState[T](stateName: String, valEncoder: Encoder[T]): ValueState[T]
/**
- * Function to create new or return existing single value state variable of given type
- * with ttl. State values will not be returned past ttlDuration, and will be eventually removed
- * from the state store. Any state update resets the ttl to current processing time plus
- * ttlDuration.
+ * Function to create new or return existing single value state variable of given type with ttl.
+ * State values will not be returned past ttlDuration, and will be eventually removed from the
+ * state store. Any state update resets the ttl to current processing time plus ttlDuration.
*
* The user must ensure to call this function only within the `init()` method of the
* StatefulProcessor.
*
- * @param stateName - name of the state variable
- * @param valEncoder - SQL encoder for state variable
- * @param ttlConfig - the ttl configuration (time to live duration etc.)
- * @tparam T - type of state variable
- * @return - instance of ValueState of type T that can be used to store state persistently
+ * @param stateName
+ * \- name of the state variable
+ * @param valEncoder
+ * \- SQL encoder for state variable
+ * @param ttlConfig
+ * \- the ttl configuration (time to live duration etc.)
+ * @tparam T
+ * \- type of state variable
+ * @return
+ * \- instance of ValueState of type T that can be used to store state persistently
*/
def getValueState[T](
stateName: String,
@@ -62,30 +69,39 @@ private[sql] trait StatefulProcessorHandle extends Serializable {
ttlConfig: TTLConfig): ValueState[T]
/**
- * Creates new or returns existing list state associated with stateName.
- * The ListState persists values of type T.
+ * Creates new or returns existing list state associated with stateName. The ListState persists
+ * values of type T.
*
- * @param stateName - name of the state variable
- * @param valEncoder - SQL encoder for state variable
- * @tparam T - type of state variable
- * @return - instance of ListState of type T that can be used to store state persistently
+ * @param stateName
+ * \- name of the state variable
+ * @param valEncoder
+ * \- SQL encoder for state variable
+ * @tparam T
+ * \- type of state variable
+ * @return
+ * \- instance of ListState of type T that can be used to store state persistently
*/
def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T]
/**
- * Function to create new or return existing list state variable of given type
- * with ttl. State values will not be returned past ttlDuration, and will be eventually removed
- * from the state store. Any values in listState which have expired after ttlDuration will not
- * be returned on get() and will be eventually removed from the state.
+ * Function to create new or return existing list state variable of given type with ttl. State
+ * values will not be returned past ttlDuration, and will be eventually removed from the state
+ * store. Any values in listState which have expired after ttlDuration will not be returned on
+ * get() and will be eventually removed from the state.
*
* The user must ensure to call this function only within the `init()` method of the
* StatefulProcessor.
*
- * @param stateName - name of the state variable
- * @param valEncoder - SQL encoder for state variable
- * @param ttlConfig - the ttl configuration (time to live duration etc.)
- * @tparam T - type of state variable
- * @return - instance of ListState of type T that can be used to store state persistently
+ * @param stateName
+ * \- name of the state variable
+ * @param valEncoder
+ * \- SQL encoder for state variable
+ * @param ttlConfig
+ * \- the ttl configuration (time to live duration etc.)
+ * @tparam T
+ * \- type of state variable
+ * @return
+ * \- instance of ListState of type T that can be used to store state persistently
*/
def getListState[T](
stateName: String,
@@ -93,15 +109,21 @@ private[sql] trait StatefulProcessorHandle extends Serializable {
ttlConfig: TTLConfig): ListState[T]
/**
- * Creates new or returns existing map state associated with stateName.
- * The MapState persists Key-Value pairs of type [K, V].
+ * Creates new or returns existing map state associated with stateName. The MapState persists
+ * Key-Value pairs of type [K, V].
*
- * @param stateName - name of the state variable
- * @param userKeyEnc - spark sql encoder for the map key
- * @param valEncoder - spark sql encoder for the map value
- * @tparam K - type of key for map state variable
- * @tparam V - type of value for map state variable
- * @return - instance of MapState of type [K,V] that can be used to store state persistently
+ * @param stateName
+ * \- name of the state variable
+ * @param userKeyEnc
+ * \- spark sql encoder for the map key
+ * @param valEncoder
+ * \- spark sql encoder for the map value
+ * @tparam K
+ * \- type of key for map state variable
+ * @tparam V
+ * \- type of value for map state variable
+ * @return
+ * \- instance of MapState of type [K,V] that can be used to store state persistently
*/
def getMapState[K, V](
stateName: String,
@@ -109,57 +131,68 @@ private[sql] trait StatefulProcessorHandle extends Serializable {
valEncoder: Encoder[V]): MapState[K, V]
/**
- * Function to create new or return existing map state variable of given type
- * with ttl. State values will not be returned past ttlDuration, and will be eventually removed
- * from the state store. Any values in mapState which have expired after ttlDuration will not
- * returned on get() and will be eventually removed from the state.
+ * Function to create new or return existing map state variable of given type with ttl. State
+ * values will not be returned past ttlDuration, and will be eventually removed from the state
+ * store. Any values in mapState which have expired after ttlDuration will not be returned on get()
+ * and will be eventually removed from the state.
*
* The user must ensure to call this function only within the `init()` method of the
* StatefulProcessor.
*
- * @param stateName - name of the state variable
- * @param userKeyEnc - spark sql encoder for the map key
- * @param valEncoder - SQL encoder for state variable
- * @param ttlConfig - the ttl configuration (time to live duration etc.)
- * @tparam K - type of key for map state variable
- * @tparam V - type of value for map state variable
- * @return - instance of MapState of type [K,V] that can be used to store state persistently
+ * @param stateName
+ * \- name of the state variable
+ * @param userKeyEnc
+ * \- spark sql encoder for the map key
+ * @param valEncoder
+ * \- SQL encoder for state variable
+ * @param ttlConfig
+ * \- the ttl configuration (time to live duration etc.)
+ * @tparam K
+ * \- type of key for map state variable
+ * @tparam V
+ * \- type of value for map state variable
+ * @return
+ * \- instance of MapState of type [K,V] that can be used to store state persistently
*/
def getMapState[K, V](
- stateName: String,
- userKeyEnc: Encoder[K],
- valEncoder: Encoder[V],
- ttlConfig: TTLConfig): MapState[K, V]
+ stateName: String,
+ userKeyEnc: Encoder[K],
+ valEncoder: Encoder[V],
+ ttlConfig: TTLConfig): MapState[K, V]
/** Function to return queryInfo for currently running task */
def getQueryInfo(): QueryInfo
/**
- * Function to register a processing/event time based timer for given implicit grouping key
- * and provided timestamp
- * @param expiryTimestampMs - timer expiry timestamp in milliseconds
+ * Function to register a processing/event time based timer for given implicit grouping key and
+ * provided timestamp
+ * @param expiryTimestampMs
+ * \- timer expiry timestamp in milliseconds
*/
def registerTimer(expiryTimestampMs: Long): Unit
/**
- * Function to delete a processing/event time based timer for given implicit grouping key
- * and provided timestamp
- * @param expiryTimestampMs - timer expiry timestamp in milliseconds
+ * Function to delete a processing/event time based timer for given implicit grouping key and
+ * provided timestamp
+ * @param expiryTimestampMs
+ * \- timer expiry timestamp in milliseconds
*/
def deleteTimer(expiryTimestampMs: Long): Unit
/**
- * Function to list all the timers registered for given implicit grouping key
- * Note: calling listTimers() within the `handleInputRows` method of the StatefulProcessor
- * will return all the unprocessed registered timers, including the one being fired within the
- * invocation of `handleInputRows`.
- * @return - list of all the registered timers for given implicit grouping key
+ * Function to list all the timers registered for given implicit grouping key. Note: calling
+ * listTimers() within the `handleInputRows` method of the StatefulProcessor will return all the
+ * unprocessed registered timers, including the one being fired within the invocation of
+ * `handleInputRows`.
+ * @return
+ * \- list of all the registered timers for given implicit grouping key
*/
def listTimers(): Iterator[Long]
/**
* Function to delete and purge state variable if defined previously
- * @param stateName - name of the state variable
+ * @param stateName
+ * \- name of the state variable
*/
def deleteIfExists(stateName: String): Unit
}
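The handle methods above compose inside a processor. Another hedged sketch, assuming the same internal API: a TTL-backed value state created in `init()` plus a timer registered from `handleInputRows`; the one-hour TTL and ten-minute timer are arbitrary illustrative choices, and deriving the expiry from wall-clock time (rather than from `timerValues`) is an assumption to keep the example self-contained.

```scala
import java.time.Duration

import org.apache.spark.sql.Encoders
import org.apache.spark.sql.streaming._

// Remembers the latest event per key and schedules a follow-up timer.
class SessionProcessor extends StatefulProcessor[String, String, String] {
  @transient private var lastEvent: ValueState[String] = _

  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
    // The stored value expires one hour after its last update (TTL semantics above).
    lastEvent = getHandle.getValueState[String](
      "lastEvent", Encoders.STRING, TTLConfig(Duration.ofHours(1)))
  }

  override def handleInputRows(
      key: String,
      inputRows: Iterator[String],
      timerValues: TimerValues,
      expiredTimerInfo: ExpiredTimerInfo): Iterator[String] = {
    inputRows.foreach(lastEvent.update)
    // registerTimer takes an epoch timestamp in milliseconds.
    getHandle.registerTimer(System.currentTimeMillis() + 10 * 60 * 1000)
    Iterator.empty
  }
}
```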
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala
similarity index 63%
rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala
rename to sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala
index fcb4bdcb327bc..a6684969ff1ec 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala
@@ -20,19 +20,24 @@ package org.apache.spark.sql.streaming
import java.util.UUID
import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper}
+import com.fasterxml.jackson.databind.node.TreeTraversingParser
import com.fasterxml.jackson.module.scala.{ClassTagExtensions, DefaultScalaModule}
-import org.json4s.{JObject, JString, JValue}
+import org.json4s.{JObject, JString}
+import org.json4s.JsonAST.JValue
import org.json4s.JsonDSL.{jobject2assoc, pair2Assoc}
import org.json4s.jackson.JsonMethods.{compact, render}
import org.apache.spark.annotation.Evolving
+import org.apache.spark.scheduler.SparkListenerEvent
/**
- * Interface for listening to events related to [[StreamingQuery StreamingQueries]].
+ * Interface for listening to events related to
+ * [[org.apache.spark.sql.api.StreamingQuery StreamingQueries]].
+ *
* @note
* The methods are not thread-safe as they may be called from different threads.
*
- * @since 3.5.0
+ * @since 2.0.0
*/
@Evolving
abstract class StreamingQueryListener extends Serializable {
@@ -42,12 +47,11 @@ abstract class StreamingQueryListener extends Serializable {
/**
* Called when a query is started.
* @note
- * This is called synchronously with
- * [[org.apache.spark.sql.streaming.DataStreamWriter `DataStreamWriter.start()`]], that is,
- * `onQueryStart` will be called on all listeners before `DataStreamWriter.start()` returns
- * the corresponding [[StreamingQuery]]. Please don't block this method as it will block your
- * query.
- * @since 3.5.0
+ * This is called synchronously with `DataStreamWriter.start()`, that is, `onQueryStart` will
+ * be called on all listeners before `DataStreamWriter.start()` returns the corresponding
+ * [[org.apache.spark.sql.api.StreamingQuery]]. Please don't block this method as it will
+ * block your query.
+ * @since 2.0.0
*/
def onQueryStarted(event: QueryStartedEvent): Unit
@@ -55,11 +59,12 @@ abstract class StreamingQueryListener extends Serializable {
* Called when there is some status update (ingestion rate updated, etc.)
*
* @note
- * This method is asynchronous. The status in [[StreamingQuery]] will always be latest no
- * matter when this method is called. Therefore, the status of [[StreamingQuery]] may be
- * changed before/when you process the event. E.g., you may find [[StreamingQuery]] is
- * terminated when you are processing `QueryProgressEvent`.
- * @since 3.5.0
+ * This method is asynchronous. The status in [[org.apache.spark.sql.api.StreamingQuery]] will
+ * always be latest no matter when this method is called. Therefore, the status of
+ * [[org.apache.spark.sql.api.StreamingQuery]] may be changed before/when you process the
+ * event. E.g., you may find [[org.apache.spark.sql.api.StreamingQuery]] is terminated when
+ * you are processing `QueryProgressEvent`.
+ * @since 2.0.0
*/
def onQueryProgress(event: QueryProgressEvent): Unit
@@ -71,24 +76,68 @@ abstract class StreamingQueryListener extends Serializable {
/**
* Called when a query is stopped, with or without error.
- * @since 3.5.0
+ * @since 2.0.0
*/
def onQueryTerminated(event: QueryTerminatedEvent): Unit
}
+/**
+ * Py4J allows a pure interface so this proxy is required.
+ */
+private[spark] trait PythonStreamingQueryListener {
+ import StreamingQueryListener._
+
+ def onQueryStarted(event: QueryStartedEvent): Unit
+
+ def onQueryProgress(event: QueryProgressEvent): Unit
+
+ def onQueryIdle(event: QueryIdleEvent): Unit
+
+ def onQueryTerminated(event: QueryTerminatedEvent): Unit
+}
+
+private[spark] class PythonStreamingQueryListenerWrapper(listener: PythonStreamingQueryListener)
+ extends StreamingQueryListener {
+ import StreamingQueryListener._
+
+ def onQueryStarted(event: QueryStartedEvent): Unit = listener.onQueryStarted(event)
+
+ def onQueryProgress(event: QueryProgressEvent): Unit = listener.onQueryProgress(event)
+
+ override def onQueryIdle(event: QueryIdleEvent): Unit = listener.onQueryIdle(event)
+
+ def onQueryTerminated(event: QueryTerminatedEvent): Unit = listener.onQueryTerminated(event)
+}
+
/**
* Companion object of [[StreamingQueryListener]] that defines the listener events.
- * @since 3.5.0
+ * @since 2.0.0
*/
@Evolving
object StreamingQueryListener extends Serializable {
+ private val mapper = {
+ val ret = new ObjectMapper() with ClassTagExtensions
+ ret.registerModule(DefaultScalaModule)
+ ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
+ ret
+ }
+
+ private case class EventParser(json: String) {
+ private val tree = mapper.readTree(json)
+ def getString(name: String): String = tree.get(name).asText()
+ def getUUID(name: String): UUID = UUID.fromString(getString(name))
+ def getProgress(name: String): StreamingQueryProgress = {
+ val parser = new TreeTraversingParser(tree.get(name), mapper)
+ parser.readValueAs(classOf[StreamingQueryProgress])
+ }
+ }
/**
* Base type of [[StreamingQueryListener]] events
- * @since 3.5.0
+ * @since 2.0.0
*/
@Evolving
- trait Event
+ trait Event extends SparkListenerEvent
/**
* Event representing the start of a query
@@ -100,7 +149,7 @@ object StreamingQueryListener extends Serializable {
* User-specified name of the query, null if not specified.
* @param timestamp
* The timestamp to start a query.
- * @since 3.5.0
+ * @since 2.1.0
*/
@Evolving
class QueryStartedEvent private[sql] (
@@ -122,25 +171,22 @@ object StreamingQueryListener extends Serializable {
}
private[spark] object QueryStartedEvent {
- private val mapper = {
- val ret = new ObjectMapper() with ClassTagExtensions
- ret.registerModule(DefaultScalaModule)
- ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
- ret
- }
- private[spark] def jsonString(event: QueryStartedEvent): String =
- mapper.writeValueAsString(event)
-
- private[spark] def fromJson(json: String): QueryStartedEvent =
- mapper.readValue[QueryStartedEvent](json)
+ private[spark] def fromJson(json: String): QueryStartedEvent = {
+ val parser = EventParser(json)
+ new QueryStartedEvent(
+ parser.getUUID("id"),
+ parser.getUUID("runId"),
+ parser.getString("name"),
+ parser.getString("timestamp"))
+ }
}
/**
* Event representing any progress updates in a query.
* @param progress
* The query progress updates.
- * @since 3.5.0
+ * @since 2.1.0
*/
@Evolving
class QueryProgressEvent private[sql] (val progress: StreamingQueryProgress)
@@ -153,18 +199,11 @@ object StreamingQueryListener extends Serializable {
}
private[spark] object QueryProgressEvent {
- private val mapper = {
- val ret = new ObjectMapper() with ClassTagExtensions
- ret.registerModule(DefaultScalaModule)
- ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
- ret
- }
-
- private[spark] def jsonString(event: QueryProgressEvent): String =
- mapper.writeValueAsString(event)
- private[spark] def fromJson(json: String): QueryProgressEvent =
- mapper.readValue[QueryProgressEvent](json)
+ private[spark] def fromJson(json: String): QueryProgressEvent = {
+ val parser = EventParser(json)
+ new QueryProgressEvent(parser.getProgress("progress"))
+ }
}
/**
@@ -193,18 +232,14 @@ object StreamingQueryListener extends Serializable {
}
private[spark] object QueryIdleEvent {
- private val mapper = {
- val ret = new ObjectMapper() with ClassTagExtensions
- ret.registerModule(DefaultScalaModule)
- ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
- ret
- }
- private[spark] def jsonString(event: QueryTerminatedEvent): String =
- mapper.writeValueAsString(event)
-
- private[spark] def fromJson(json: String): QueryTerminatedEvent =
- mapper.readValue[QueryTerminatedEvent](json)
+ private[spark] def fromJson(json: String): QueryIdleEvent = {
+ val parser = EventParser(json)
+ new QueryIdleEvent(
+ parser.getUUID("id"),
+ parser.getUUID("runId"),
+ parser.getString("timestamp"))
+ }
}
/**
@@ -221,7 +256,7 @@ object StreamingQueryListener extends Serializable {
* The error class from the exception if the query was terminated with an exception which is a
* part of error class framework. If the query was terminated without an exception, or the
* exception is not a part of error class framework, it will be `None`.
- * @since 3.5.0
+ * @since 2.1.0
*/
@Evolving
class QueryTerminatedEvent private[sql] (
@@ -247,17 +282,13 @@ object StreamingQueryListener extends Serializable {
}
private[spark] object QueryTerminatedEvent {
- private val mapper = {
- val ret = new ObjectMapper() with ClassTagExtensions
- ret.registerModule(DefaultScalaModule)
- ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
- ret
+ private[spark] def fromJson(json: String): QueryTerminatedEvent = {
+ val parser = EventParser(json)
+ new QueryTerminatedEvent(
+ parser.getUUID("id"),
+ parser.getUUID("runId"),
+ Option(parser.getString("exception")),
+ Option(parser.getString("errorClassOnException")))
}
-
- private[spark] def jsonString(event: QueryTerminatedEvent): String =
- mapper.writeValueAsString(event)
-
- private[spark] def fromJson(json: String): QueryTerminatedEvent =
- mapper.readValue[QueryTerminatedEvent](json)
}
}
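For orientation, these events reach user code through the listener callbacks; a minimal sketch (not part of this patch), assuming a running `SparkSession` named `spark`; only the three abstract callbacks are overridden here since `onQueryIdle` has a default no-op implementation:

  import org.apache.spark.sql.streaming.StreamingQueryListener
  import org.apache.spark.sql.streaming.StreamingQueryListener._

  // Log the lifecycle events whose JSON handling is reworked above.
  spark.streams.addListener(new StreamingQueryListener {
    override def onQueryStarted(event: QueryStartedEvent): Unit =
      println(s"started: id=${event.id} run=${event.runId} name=${event.name}")
    override def onQueryProgress(event: QueryProgressEvent): Unit =
      println(s"progress: ${event.progress.json}")
    override def onQueryTerminated(event: QueryTerminatedEvent): Unit =
      println(s"terminated: id=${event.id} exception=${event.exception}")
  })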
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala
similarity index 95%
rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala
rename to sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala
index cdda25876b250..c37cdd00c8866 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala
@@ -36,7 +36,7 @@ import org.apache.spark.annotation.Evolving
* True when the trigger is actively firing, false when waiting for the next trigger time.
* Doesn't apply to ContinuousExecution where it is always false.
*
- * @since 3.5.0
+ * @since 2.1.0
*/
@Evolving
class StreamingQueryStatus protected[sql] (
@@ -44,7 +44,6 @@ class StreamingQueryStatus protected[sql] (
val isDataAvailable: Boolean,
val isTriggerActive: Boolean)
extends Serializable {
- // This is a copy of the same class in sql/core/.../streaming/StreamingQueryStatus.scala
/** The compact JSON representation of this status. */
def json: String = compact(render(jsonValue))
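As a usage note (a sketch, not part of this patch), the status above is normally obtained from an active query; `query` is assumed to be a started `StreamingQuery`:

  // Inspect trigger state; `json` is the compact representation defined above.
  val status = query.status
  println(status.isDataAvailable)  // new data waiting to be processed?
  println(status.isTriggerActive)  // is a trigger currently firing?
  println(status.json)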
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/TTLConfig.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/TTLConfig.scala
index 576e09d5d7fe2..ce786aa943d89 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/TTLConfig.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/TTLConfig.scala
@@ -20,11 +20,11 @@ package org.apache.spark.sql.streaming
import java.time.Duration
/**
- * TTL Configuration for state variable. State values will not be returned past ttlDuration,
- * and will be eventually removed from the state store. Any state update resets the ttl to
- * current processing time plus ttlDuration.
+ * TTL Configuration for state variable. State values will not be returned past ttlDuration, and
+ * will be eventually removed from the state store. Any state update resets the ttl to current
+ * processing time plus ttlDuration.
*
- * @param ttlDuration time to live duration for state
- * stored in the state variable.
+ * @param ttlDuration
+ * time to live duration for state stored in the state variable.
*/
case class TTLConfig(ttlDuration: Duration)
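A brief sketch of constructing the configuration described above; wiring it into a state variable is part of the arbitrary state API and is omitted here:

  import java.time.Duration
  import org.apache.spark.sql.streaming.TTLConfig

  // State written under this config expires 30 minutes after its last update.
  val ttl = TTLConfig(Duration.ofMinutes(30))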
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/TimerValues.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/TimerValues.scala
index f0aef58228188..04c5f59428f7f 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/TimerValues.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/TimerValues.scala
@@ -22,25 +22,29 @@ import java.io.Serializable
import org.apache.spark.annotation.{Evolving, Experimental}
/**
- * Class used to provide access to timer values for processing and event time populated
- * before method invocations using the arbitrary state API v2.
+ * Class used to provide access to timer values for processing and event time populated before
+ * method invocations using the arbitrary state API v2.
*/
@Experimental
@Evolving
private[sql] trait TimerValues extends Serializable {
+
/**
* Get the current processing time as milliseconds in epoch time.
- * @note This will return a constant value throughout the duration of a streaming query trigger,
- * even if the trigger is re-executed.
+ * @note
+ * This will return a constant value throughout the duration of a streaming query trigger,
+ * even if the trigger is re-executed.
*/
def getCurrentProcessingTimeInMs(): Long
/**
* Get the current event time watermark as milliseconds in epoch time.
*
- * @note This can be called only when watermark is set before calling `transformWithState`.
- * @note The watermark gets propagated at the end of each query. As a result, this method will
- * return 0 (1970-01-01T00:00:00) for the first micro-batch.
+ * @note
+ * This can be called only when watermark is set before calling `transformWithState`.
+ * @note
+ * The watermark gets propagated at the end of each query. As a result, this method will
+ * return 0 (1970-01-01T00:00:00) for the first micro-batch.
*/
def getCurrentWatermarkInMs(): Long
}
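To make the two accessors concrete, a hedged sketch assuming `timerValues` is a `TimerValues` instance supplied to a stateful processor by the framework:

  // Constant within a single trigger, even if the trigger is re-executed.
  val nowMs: Long = timerValues.getCurrentProcessingTimeInMs()

  // 0 (1970-01-01T00:00:00) on the first micro-batch, since the watermark
  // only propagates at the end of each query run.
  val watermarkMs: Long = timerValues.getCurrentWatermarkInMs()
  val lagMs = nowMs - watermarkMs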
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala
index 8a2661e1a55b1..edb5f65365ab8 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala
@@ -24,8 +24,7 @@ import org.apache.spark.annotation.{Evolving, Experimental}
@Experimental
@Evolving
/**
- * Interface used for arbitrary stateful operations with the v2 API to capture
- * single value state.
+ * Interface used for arbitrary stateful operations with the v2 API to capture single value state.
*/
private[sql] trait ValueState[S] extends Serializable {
@@ -34,7 +33,8 @@ private[sql] trait ValueState[S] extends Serializable {
/**
* Get the state value if it exists
- * @throws java.util.NoSuchElementException if the state does not exist
+ * @throws java.util.NoSuchElementException
+ * if the state does not exist
*/
@throws[NoSuchElementException]
def get(): S
@@ -45,7 +45,8 @@ private[sql] trait ValueState[S] extends Serializable {
/**
* Update the value of the state.
*
- * @param newState the new value
+ * @param newState
+ * the new value
*/
def update(newState: S): Unit
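A minimal sketch of the contract above, assuming `counter` is a `ValueState[Long]` obtained from the framework; `get()` throws when no value has been written yet, so the first update seeds the state:

  // Increment a per-key counter, treating "no state yet" as zero.
  val next =
    try counter.get() + 1L
    catch { case _: NoSuchElementException => 1L }
  counter.update(next)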
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/progress.scala
similarity index 99%
rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/progress.scala
rename to sql/api/src/main/scala/org/apache/spark/sql/streaming/progress.scala
index ebd13bc248f97..b7573cb280444 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/progress.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/progress.scala
@@ -139,7 +139,7 @@ class StateOperatorProgress private[spark] (
* Information about operators in the query that store state.
* @param sources
* detailed statistics on data being read from each of the streaming sources.
- * @since 3.5.0
+ * @since 2.1.0
*/
@Evolving
class StreamingQueryProgress private[spark] (
@@ -195,7 +195,7 @@ class StreamingQueryProgress private[spark] (
}
private[spark] object StreamingQueryProgress {
- private val mapper = {
+ private[this] val mapper = {
val ret = new ObjectMapper() with ClassTagExtensions
ret.registerModule(DefaultScalaModule)
ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
@@ -227,7 +227,7 @@ private[spark] object StreamingQueryProgress {
* The rate at which data is arriving from this source.
* @param processedRowsPerSecond
* The rate at which data from this source is being processed by Spark.
- * @since 3.5.0
+ * @since 2.1.0
*/
@Evolving
class SourceProgress protected[spark] (
@@ -276,7 +276,7 @@ class SourceProgress protected[spark] (
* @param numOutputRows
* Number of rows written to the sink or -1 for Continuous Mode (temporarily) or Sink V1 (until
* decommissioned).
- * @since 3.5.0
+ * @since 2.1.0
*/
@Evolving
class SinkProgress protected[spark] (
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
index 67f634f8379cd..9590fb23e16b1 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.errors.DataTypeErrors
* A non-concrete data type, reserved for internal uses.
*/
private[sql] abstract class AbstractDataType {
+
/**
* The default concrete type to use if we want to cast a null literal into this type.
*/
@@ -47,7 +48,6 @@ private[sql] abstract class AbstractDataType {
private[sql] def simpleString: String
}
-
/**
* A collection of types that can be used to specify type constraints. The sequence also specifies
* precedence: an earlier type takes precedence over a later type.
@@ -59,7 +59,7 @@ private[sql] abstract class AbstractDataType {
* This means that we prefer StringType over BinaryType if it is possible to cast to StringType.
*/
private[sql] class TypeCollection(private val types: Seq[AbstractDataType])
- extends AbstractDataType {
+ extends AbstractDataType {
require(types.nonEmpty, s"TypeCollection ($types) cannot be empty")
@@ -73,22 +73,20 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType])
}
}
-
private[sql] object TypeCollection {
/**
* Types that include numeric types and ANSI interval types.
*/
- val NumericAndAnsiInterval = TypeCollection(
- NumericType,
- DayTimeIntervalType,
- YearMonthIntervalType)
+ val NumericAndAnsiInterval =
+ TypeCollection(NumericType, DayTimeIntervalType, YearMonthIntervalType)
/**
- * Types that include numeric and ANSI interval types, and additionally the legacy interval type.
- * They are only used in unary_minus, unary_positive, add and subtract operations.
+ * Types that include numeric and ANSI interval types, and additionally the legacy interval
+ * type. They are only used in unary_minus, unary_positive, add and subtract operations.
*/
- val NumericAndInterval = new TypeCollection(NumericAndAnsiInterval.types :+ CalendarIntervalType)
+ val NumericAndInterval = new TypeCollection(
+ NumericAndAnsiInterval.types :+ CalendarIntervalType)
def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types)
@@ -98,7 +96,6 @@ private[sql] object TypeCollection {
}
}
-
/**
* An `AbstractDataType` that matches any concrete data types.
*/
@@ -114,15 +111,14 @@ protected[sql] object AnyDataType extends AbstractDataType with Serializable {
override private[sql] def acceptsType(other: DataType): Boolean = true
}
-
/**
- * An internal type used to represent everything that is not null, UDTs, arrays, structs, and maps.
+ * An internal type used to represent everything that is not null, UDTs, arrays, structs, and
+ * maps.
*/
protected[sql] abstract class AtomicType extends DataType
object AtomicType
-
/**
* Numeric data types.
*
@@ -131,7 +127,6 @@ object AtomicType
@Stable
abstract class NumericType extends AtomicType
-
private[spark] object NumericType extends AbstractDataType {
override private[spark] def defaultConcreteType: DataType = DoubleType
@@ -141,22 +136,19 @@ private[spark] object NumericType extends AbstractDataType {
other.isInstanceOf[NumericType]
}
-
private[sql] object IntegralType extends AbstractDataType {
override private[sql] def defaultConcreteType: DataType = IntegerType
override private[sql] def simpleString: String = "integral"
- override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[IntegralType]
+ override private[sql] def acceptsType(other: DataType): Boolean =
+ other.isInstanceOf[IntegralType]
}
-
private[sql] abstract class IntegralType extends NumericType
-
private[sql] object FractionalType
-
private[sql] abstract class FractionalType extends NumericType
private[sql] object AnyTimestampType extends AbstractDataType with Serializable {
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
index e5af472d90e25..fc32248b4baf3 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
@@ -29,12 +29,14 @@ import org.apache.spark.sql.catalyst.util.StringConcat
*/
@Stable
object ArrayType extends AbstractDataType {
+
/**
* Construct a [[ArrayType]] object with the given element type. The `containsNull` is true.
*/
def apply(elementType: DataType): ArrayType = ArrayType(elementType, containsNull = true)
- override private[sql] def defaultConcreteType: DataType = ArrayType(NullType, containsNull = true)
+ override private[sql] def defaultConcreteType: DataType =
+ ArrayType(NullType, containsNull = true)
override private[sql] def acceptsType(other: DataType): Boolean = {
other.isInstanceOf[ArrayType]
@@ -44,18 +46,19 @@ object ArrayType extends AbstractDataType {
}
/**
- * The data type for collections of multiple values.
- * Internally these are represented as columns that contain a ``scala.collection.Seq``.
+ * The data type for collections of multiple values. Internally these are represented as columns
+ * that contain a ``scala.collection.Seq``.
*
* Please use `DataTypes.createArrayType()` to create a specific instance.
*
- * An [[ArrayType]] object comprises two fields, `elementType: [[DataType]]` and
- * `containsNull: Boolean`.
- * The field of `elementType` is used to specify the type of array elements.
- * The field of `containsNull` is used to specify if the array can have `null` values.
+ * An [[ArrayType]] object comprises two fields, `elementType: [[DataType]]` and `containsNull:
+ * Boolean`. The field of `elementType` is used to specify the type of array elements. The field
+ * of `containsNull` is used to specify if the array can have `null` values.
*
- * @param elementType The data type of values.
- * @param containsNull Indicates if the array can have `null` values
+ * @param elementType
+ * The data type of values.
+ * @param containsNull
+ * Indicates if the array can have `null` values
*
* @since 1.3.0
*/
@@ -82,8 +85,8 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT
("containsNull" -> containsNull)
/**
- * The default size of a value of the ArrayType is the default size of the element type.
- * We assume that there is only 1 element on average in an array. See SPARK-18853.
+ * The default size of a value of the ArrayType is the default size of the element type. We
+ * assume that there is only 1 element on average in an array. See SPARK-18853.
*/
override def defaultSize: Int = 1 * elementType.defaultSize
@@ -97,8 +100,8 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT
ArrayType(elementType.asNullable, containsNull = true)
/**
- * Returns the same data type but set all nullability fields are true
- * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`).
+ * Returns the same data type but with all nullability fields set to true (`StructField.nullable`,
+ * `ArrayType.containsNull`, and `MapType.valueContainsNull`).
*
* @since 4.0.0
*/
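A short illustration of the two fields described above, using only the constructors shown in this file:

  import org.apache.spark.sql.types._

  val a = ArrayType(IntegerType)                        // containsNull defaults to true
  val b = ArrayType(IntegerType, containsNull = false)  // elements may not be null
  a.containsNull  // true
  b.elementType   // IntegerType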
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/BinaryType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/BinaryType.scala
index c280f66f943aa..20bfd9bf5684f 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/BinaryType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/BinaryType.scala
@@ -20,11 +20,12 @@ package org.apache.spark.sql.types
import org.apache.spark.annotation.Stable
/**
- * The data type representing `Array[Byte]` values.
- * Please use the singleton `DataTypes.BinaryType`.
+ * The data type representing `Array[Byte]` values. Please use the singleton
+ * `DataTypes.BinaryType`.
*/
@Stable
-class BinaryType private() extends AtomicType {
+class BinaryType private () extends AtomicType {
+
/**
* The default size of a value of the BinaryType is 100 bytes.
*/
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/BooleanType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/BooleanType.scala
index 836c41a996ac4..090c56eaf9af7 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/BooleanType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/BooleanType.scala
@@ -25,7 +25,8 @@ import org.apache.spark.annotation.Stable
* @since 1.3.0
*/
@Stable
-class BooleanType private() extends AtomicType {
+class BooleanType private () extends AtomicType {
+
/**
* The default size of a value of the BooleanType is 1 byte.
*/
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/ByteType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/ByteType.scala
index 546ac02f2639a..4a27a00dacb8a 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/ByteType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/ByteType.scala
@@ -25,7 +25,8 @@ import org.apache.spark.annotation.Stable
* @since 1.3.0
*/
@Stable
-class ByteType private() extends IntegralType {
+class ByteType private () extends IntegralType {
+
/**
* The default size of a value of the ByteType is 1 byte.
*/
@@ -36,7 +37,6 @@ class ByteType private() extends IntegralType {
private[spark] override def asNullable: ByteType = this
}
-
/**
* @since 1.3.0
*/
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala
index d506a1521e183..f6b6256db0417 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala
@@ -21,19 +21,19 @@ import org.apache.spark.annotation.Stable
/**
* The data type representing calendar intervals. The calendar interval is stored internally in
- * three components:
- * an integer value representing the number of `months` in this interval,
- * an integer value representing the number of `days` in this interval,
- * a long value representing the number of `microseconds` in this interval.
+ * three components: an integer value representing the number of `months` in this interval, an
+ * integer value representing the number of `days` in this interval, a long value representing the
+ * number of `microseconds` in this interval.
*
* Please use the singleton `DataTypes.CalendarIntervalType` to refer the type.
*
- * @note Calendar intervals are not comparable.
+ * @note
+ * Calendar intervals are not comparable.
*
* @since 1.5.0
*/
@Stable
-class CalendarIntervalType private() extends DataType {
+class CalendarIntervalType private () extends DataType {
override def defaultSize: Int = 16
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 277d5c9458d6f..008c9cd07076c 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -49,6 +49,7 @@ import org.apache.spark.util.SparkClassUtils
@JsonSerialize(using = classOf[DataTypeJsonSerializer])
@JsonDeserialize(using = classOf[DataTypeJsonDeserializer])
abstract class DataType extends AbstractDataType {
+
/**
* The default size of a value of this data type, used internally for size estimation.
*/
@@ -57,7 +58,9 @@ abstract class DataType extends AbstractDataType {
/** Name of the type used in JSON serialization. */
def typeName: String = {
this.getClass.getSimpleName
- .stripSuffix("$").stripSuffix("Type").stripSuffix("UDT")
+ .stripSuffix("$")
+ .stripSuffix("Type")
+ .stripSuffix("UDT")
.toLowerCase(Locale.ROOT)
}
@@ -92,8 +95,8 @@ abstract class DataType extends AbstractDataType {
}
/**
- * Returns the same data type but set all nullability fields are true
- * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`).
+ * Returns the same data type but with all nullability fields set to true (`StructField.nullable`,
+ * `ArrayType.containsNull`, and `MapType.valueContainsNull`).
*/
private[spark] def asNullable: DataType
@@ -107,7 +110,6 @@ abstract class DataType extends AbstractDataType {
override private[sql] def acceptsType(other: DataType): Boolean = sameType(other)
}
-
/**
* @since 1.3.0
*/
@@ -128,14 +130,18 @@ object DataType {
}
/**
- * Parses data type from a string with schema. It calls `parser` for `schema`.
- * If it fails, calls `fallbackParser`. If the fallback function fails too, combines error message
- * from `parser` and `fallbackParser`.
+ * Parses data type from a string with schema. It calls `parser` for `schema`. If it fails,
+ * calls `fallbackParser`. If the fallback function fails too, combines error message from
+ * `parser` and `fallbackParser`.
*
- * @param schema The schema string to parse by `parser` or `fallbackParser`.
- * @param parser The function that should be invoke firstly.
- * @param fallbackParser The function that is called when `parser` fails.
- * @return The data type parsed from the `schema` schema.
+ * @param schema
+ * The schema string to parse by `parser` or `fallbackParser`.
+ * @param parser
+ * The function that should be invoked first.
+ * @param fallbackParser
+ * The function that is called when `parser` fails.
+ * @return
+ * The data type parsed from the `schema` string.
*/
def parseTypeWithFallback(
schema: String,
@@ -161,8 +167,20 @@ object DataType {
def fromJson(json: String): DataType = parseDataType(parse(json))
private val otherTypes = {
- Seq(NullType, DateType, TimestampType, BinaryType, IntegerType, BooleanType, LongType,
- DoubleType, FloatType, ShortType, ByteType, StringType, CalendarIntervalType,
+ Seq(
+ NullType,
+ DateType,
+ TimestampType,
+ BinaryType,
+ IntegerType,
+ BooleanType,
+ LongType,
+ DoubleType,
+ FloatType,
+ ShortType,
+ ByteType,
+ StringType,
+ CalendarIntervalType,
DayTimeIntervalType(DAY),
DayTimeIntervalType(DAY, HOUR),
DayTimeIntervalType(DAY, MINUTE),
@@ -178,7 +196,8 @@ object DataType {
YearMonthIntervalType(YEAR, MONTH),
TimestampNTZType,
VariantType)
- .map(t => t.typeName -> t).toMap
+ .map(t => t.typeName -> t)
+ .toMap
}
/** Given the string representation of a type, return its DataType */
@@ -191,11 +210,12 @@ object DataType {
// For backwards compatibility, previously the type name of NullType is "null"
case "null" => NullType
case "timestamp_ltz" => TimestampType
- case other => otherTypes.getOrElse(
- other,
- throw new SparkIllegalArgumentException(
- errorClass = "INVALID_JSON_DATA_TYPE",
- messageParameters = Map("invalidType" -> name)))
+ case other =>
+ otherTypes.getOrElse(
+ other,
+ throw new SparkIllegalArgumentException(
+ errorClass = "INVALID_JSON_DATA_TYPE",
+ messageParameters = Map("invalidType" -> name)))
}
}
@@ -220,56 +240,55 @@ object DataType {
}
case JSortedObject(
- ("containsNull", JBool(n)),
- ("elementType", t: JValue),
- ("type", JString("array"))) =>
+ ("containsNull", JBool(n)),
+ ("elementType", t: JValue),
+ ("type", JString("array"))) =>
assertValidTypeForCollations(fieldPath, "array", collationsMap)
val elementType = parseDataType(t, appendFieldToPath(fieldPath, "element"), collationsMap)
ArrayType(elementType, n)
case JSortedObject(
- ("keyType", k: JValue),
- ("type", JString("map")),
- ("valueContainsNull", JBool(n)),
- ("valueType", v: JValue)) =>
+ ("keyType", k: JValue),
+ ("type", JString("map")),
+ ("valueContainsNull", JBool(n)),
+ ("valueType", v: JValue)) =>
assertValidTypeForCollations(fieldPath, "map", collationsMap)
val keyType = parseDataType(k, appendFieldToPath(fieldPath, "key"), collationsMap)
val valueType = parseDataType(v, appendFieldToPath(fieldPath, "value"), collationsMap)
MapType(keyType, valueType, n)
- case JSortedObject(
- ("fields", JArray(fields)),
- ("type", JString("struct"))) =>
+ case JSortedObject(("fields", JArray(fields)), ("type", JString("struct"))) =>
assertValidTypeForCollations(fieldPath, "struct", collationsMap)
StructType(fields.map(parseStructField))
// Scala/Java UDT
case JSortedObject(
- ("class", JString(udtClass)),
- ("pyClass", _),
- ("sqlType", _),
- ("type", JString("udt"))) =>
+ ("class", JString(udtClass)),
+ ("pyClass", _),
+ ("sqlType", _),
+ ("type", JString("udt"))) =>
SparkClassUtils.classForName[UserDefinedType[_]](udtClass).getConstructor().newInstance()
// Python UDT
case JSortedObject(
- ("pyClass", JString(pyClass)),
- ("serializedClass", JString(serialized)),
- ("sqlType", v: JValue),
- ("type", JString("udt"))) =>
- new PythonUserDefinedType(parseDataType(v), pyClass, serialized)
-
- case other => throw new SparkIllegalArgumentException(
- errorClass = "INVALID_JSON_DATA_TYPE",
- messageParameters = Map("invalidType" -> compact(render(other))))
+ ("pyClass", JString(pyClass)),
+ ("serializedClass", JString(serialized)),
+ ("sqlType", v: JValue),
+ ("type", JString("udt"))) =>
+ new PythonUserDefinedType(parseDataType(v), pyClass, serialized)
+
+ case other =>
+ throw new SparkIllegalArgumentException(
+ errorClass = "INVALID_JSON_DATA_TYPE",
+ messageParameters = Map("invalidType" -> compact(render(other))))
}
private def parseStructField(json: JValue): StructField = json match {
case JSortedObject(
- ("metadata", JObject(metadataFields)),
- ("name", JString(name)),
- ("nullable", JBool(nullable)),
- ("type", dataType: JValue)) =>
+ ("metadata", JObject(metadataFields)),
+ ("name", JString(name)),
+ ("nullable", JBool(nullable)),
+ ("type", dataType: JValue)) =>
val collationsMap = getCollationsMap(metadataFields)
val metadataWithoutCollations =
JObject(metadataFields.filterNot(_._1 == COLLATIONS_METADATA_KEY))
@@ -280,18 +299,17 @@ object DataType {
Metadata.fromJObject(metadataWithoutCollations))
// Support reading schema when 'metadata' is missing.
case JSortedObject(
- ("name", JString(name)),
- ("nullable", JBool(nullable)),
- ("type", dataType: JValue)) =>
+ ("name", JString(name)),
+ ("nullable", JBool(nullable)),
+ ("type", dataType: JValue)) =>
StructField(name, parseDataType(dataType), nullable)
// Support reading schema when 'nullable' is missing.
- case JSortedObject(
- ("name", JString(name)),
- ("type", dataType: JValue)) =>
+ case JSortedObject(("name", JString(name)), ("type", dataType: JValue)) =>
StructField(name, parseDataType(dataType))
- case other => throw new SparkIllegalArgumentException(
- errorClass = "INVALID_JSON_DATA_TYPE",
- messageParameters = Map("invalidType" -> compact(render(other))))
+ case other =>
+ throw new SparkIllegalArgumentException(
+ errorClass = "INVALID_JSON_DATA_TYPE",
+ messageParameters = Map("invalidType" -> compact(render(other))))
}
private def assertValidTypeForCollations(
@@ -319,13 +337,12 @@ object DataType {
val collationsJsonOpt = metadataFields.find(_._1 == COLLATIONS_METADATA_KEY).map(_._2)
collationsJsonOpt match {
case Some(JObject(fields)) =>
- fields.collect {
- case (fieldPath, JString(collation)) =>
- collation.split("\\.", 2) match {
- case Array(provider: String, collationName: String) =>
- CollationFactory.assertValidProvider(provider)
- fieldPath -> collationName
- }
+ fields.collect { case (fieldPath, JString(collation)) =>
+ collation.split("\\.", 2) match {
+ case Array(provider: String, collationName: String) =>
+ CollationFactory.assertValidProvider(provider)
+ fieldPath -> collationName
+ }
}.toMap
case _ => Map.empty
@@ -356,15 +373,15 @@ object DataType {
* Compares two types, ignoring compatible nullability of ArrayType, MapType, StructType.
*
* Compatible nullability is defined as follows:
- * - If `from` and `to` are ArrayTypes, `from` has a compatible nullability with `to`
- * if and only if `to.containsNull` is true, or both of `from.containsNull` and
- * `to.containsNull` are false.
- * - If `from` and `to` are MapTypes, `from` has a compatible nullability with `to`
- * if and only if `to.valueContainsNull` is true, or both of `from.valueContainsNull` and
- * `to.valueContainsNull` are false.
- * - If `from` and `to` are StructTypes, `from` has a compatible nullability with `to`
- * if and only if for all every pair of fields, `to.nullable` is true, or both
- * of `fromField.nullable` and `toField.nullable` are false.
+ * - If `from` and `to` are ArrayTypes, `from` has a compatible nullability with `to` if and
+ * only if `to.containsNull` is true, or both of `from.containsNull` and `to.containsNull`
+ * are false.
+ * - If `from` and `to` are MapTypes, `from` has a compatible nullability with `to` if and
+ * only if `to.valueContainsNull` is true, or both of `from.valueContainsNull` and
+ * `to.valueContainsNull` are false.
+ * - If `from` and `to` are StructTypes, `from` has a compatible nullability with `to` if and
+ * only if for every pair of fields, `to.nullable` is true, or both of
+ * `fromField.nullable` and `toField.nullable` are false.
*/
private[sql] def equalsIgnoreCompatibleNullability(from: DataType, to: DataType): Boolean = {
equalsIgnoreCompatibleNullability(from, to, ignoreName = false)
@@ -375,15 +392,15 @@ object DataType {
* also the field name. It compares based on the position.
*
* Compatible nullability is defined as follows:
- * - If `from` and `to` are ArrayTypes, `from` has a compatible nullability with `to`
- * if and only if `to.containsNull` is true, or both of `from.containsNull` and
- * `to.containsNull` are false.
- * - If `from` and `to` are MapTypes, `from` has a compatible nullability with `to`
- * if and only if `to.valueContainsNull` is true, or both of `from.valueContainsNull` and
- * `to.valueContainsNull` are false.
- * - If `from` and `to` are StructTypes, `from` has a compatible nullability with `to`
- * if and only if for all every pair of fields, `to.nullable` is true, or both
- * of `fromField.nullable` and `toField.nullable` are false.
+ * - If `from` and `to` are ArrayTypes, `from` has a compatible nullability with `to` if and
+ * only if `to.containsNull` is true, or both of `from.containsNull` and `to.containsNull`
+ * are false.
+ * - If `from` and `to` are MapTypes, `from` has a compatible nullability with `to` if and
+ * only if `to.valueContainsNull` is true, or both of `from.valueContainsNull` and
+ * `to.valueContainsNull` are false.
+ * - If `from` and `to` are StructTypes, `from` has a compatible nullability with `to` if and
+ * only if for every pair of fields, `to.nullable` is true, or both of
+ * `fromField.nullable` and `toField.nullable` are false.
*/
private[sql] def equalsIgnoreNameAndCompatibleNullability(
from: DataType,
@@ -401,16 +418,16 @@ object DataType {
case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
(tn || !fn) &&
- equalsIgnoreCompatibleNullability(fromKey, toKey, ignoreName) &&
- equalsIgnoreCompatibleNullability(fromValue, toValue, ignoreName)
+ equalsIgnoreCompatibleNullability(fromKey, toKey, ignoreName) &&
+ equalsIgnoreCompatibleNullability(fromValue, toValue, ignoreName)
case (StructType(fromFields), StructType(toFields)) =>
fromFields.length == toFields.length &&
- fromFields.zip(toFields).forall { case (fromField, toField) =>
- (ignoreName || fromField.name == toField.name) &&
- (toField.nullable || !fromField.nullable) &&
- equalsIgnoreCompatibleNullability(fromField.dataType, toField.dataType, ignoreName)
- }
+ fromFields.zip(toFields).forall { case (fromField, toField) =>
+ (ignoreName || fromField.name == toField.name) &&
+ (toField.nullable || !fromField.nullable) &&
+ equalsIgnoreCompatibleNullability(fromField.dataType, toField.dataType, ignoreName)
+ }
case (fromDataType, toDataType) => fromDataType == toDataType
}
@@ -420,42 +437,42 @@ object DataType {
* Check if `from` is equal to `to` type except for collations, which are checked to be
* compatible so that data of type `from` can be interpreted as of type `to`.
*/
- private[sql] def equalsIgnoreCompatibleCollation(
- from: DataType,
- to: DataType): Boolean = {
+ private[sql] def equalsIgnoreCompatibleCollation(from: DataType, to: DataType): Boolean = {
(from, to) match {
// String types with possibly different collations are compatible.
case (_: StringType, _: StringType) => true
case (ArrayType(fromElement, fromContainsNull), ArrayType(toElement, toContainsNull)) =>
(fromContainsNull == toContainsNull) &&
- equalsIgnoreCompatibleCollation(fromElement, toElement)
+ equalsIgnoreCompatibleCollation(fromElement, toElement)
- case (MapType(fromKey, fromValue, fromContainsNull),
- MapType(toKey, toValue, toContainsNull)) =>
+ case (
+ MapType(fromKey, fromValue, fromContainsNull),
+ MapType(toKey, toValue, toContainsNull)) =>
fromContainsNull == toContainsNull &&
- // Map keys cannot change collation.
- fromKey == toKey &&
- equalsIgnoreCompatibleCollation(fromValue, toValue)
+ // Map keys cannot change collation.
+ fromKey == toKey &&
+ equalsIgnoreCompatibleCollation(fromValue, toValue)
case (StructType(fromFields), StructType(toFields)) =>
fromFields.length == toFields.length &&
- fromFields.zip(toFields).forall { case (fromField, toField) =>
- fromField.name == toField.name &&
- fromField.nullable == toField.nullable &&
- fromField.metadata == toField.metadata &&
- equalsIgnoreCompatibleCollation(fromField.dataType, toField.dataType)
- }
+ fromFields.zip(toFields).forall { case (fromField, toField) =>
+ fromField.name == toField.name &&
+ fromField.nullable == toField.nullable &&
+ fromField.metadata == toField.metadata &&
+ equalsIgnoreCompatibleCollation(fromField.dataType, toField.dataType)
+ }
case (fromDataType, toDataType) => fromDataType == toDataType
}
}
/**
- * Returns true if the two data types share the same "shape", i.e. the types
- * are the same, but the field names don't need to be the same.
+ * Returns true if the two data types share the same "shape", i.e. the types are the same, but
+ * the field names don't need to be the same.
*
- * @param ignoreNullability whether to ignore nullability when comparing the types
+ * @param ignoreNullability
+ * whether to ignore nullability when comparing the types
*/
def equalsStructurally(
from: DataType,
@@ -464,20 +481,21 @@ object DataType {
(from, to) match {
case (left: ArrayType, right: ArrayType) =>
equalsStructurally(left.elementType, right.elementType, ignoreNullability) &&
- (ignoreNullability || left.containsNull == right.containsNull)
+ (ignoreNullability || left.containsNull == right.containsNull)
case (left: MapType, right: MapType) =>
equalsStructurally(left.keyType, right.keyType, ignoreNullability) &&
- equalsStructurally(left.valueType, right.valueType, ignoreNullability) &&
- (ignoreNullability || left.valueContainsNull == right.valueContainsNull)
+ equalsStructurally(left.valueType, right.valueType, ignoreNullability) &&
+ (ignoreNullability || left.valueContainsNull == right.valueContainsNull)
case (StructType(fromFields), StructType(toFields)) =>
fromFields.length == toFields.length &&
- fromFields.zip(toFields)
- .forall { case (l, r) =>
- equalsStructurally(l.dataType, r.dataType, ignoreNullability) &&
- (ignoreNullability || l.nullable == r.nullable)
- }
+ fromFields
+ .zip(toFields)
+ .forall { case (l, r) =>
+ equalsStructurally(l.dataType, r.dataType, ignoreNullability) &&
+ (ignoreNullability || l.nullable == r.nullable)
+ }
case (fromDataType, toDataType) => fromDataType == toDataType
}
@@ -496,14 +514,15 @@ object DataType {
case (left: MapType, right: MapType) =>
equalsStructurallyByName(left.keyType, right.keyType, resolver) &&
- equalsStructurallyByName(left.valueType, right.valueType, resolver)
+ equalsStructurallyByName(left.valueType, right.valueType, resolver)
case (StructType(fromFields), StructType(toFields)) =>
fromFields.length == toFields.length &&
- fromFields.zip(toFields)
- .forall { case (l, r) =>
- resolver(l.name, r.name) && equalsStructurallyByName(l.dataType, r.dataType, resolver)
- }
+ fromFields
+ .zip(toFields)
+ .forall { case (l, r) =>
+ resolver(l.name, r.name) && equalsStructurallyByName(l.dataType, r.dataType, resolver)
+ }
case _ => true
}
@@ -518,12 +537,12 @@ object DataType {
equalsIgnoreNullability(leftElementType, rightElementType)
case (MapType(leftKeyType, leftValueType, _), MapType(rightKeyType, rightValueType, _)) =>
equalsIgnoreNullability(leftKeyType, rightKeyType) &&
- equalsIgnoreNullability(leftValueType, rightValueType)
+ equalsIgnoreNullability(leftValueType, rightValueType)
case (StructType(leftFields), StructType(rightFields)) =>
leftFields.length == rightFields.length &&
- leftFields.zip(rightFields).forall { case (l, r) =>
- l.name == r.name && equalsIgnoreNullability(l.dataType, r.dataType)
- }
+ leftFields.zip(rightFields).forall { case (l, r) =>
+ l.name == r.name && equalsIgnoreNullability(l.dataType, r.dataType)
+ }
case (l, r) => l == r
}
}
@@ -539,14 +558,14 @@ object DataType {
case (MapType(fromKey, fromValue, _), MapType(toKey, toValue, _)) =>
equalsIgnoreCaseAndNullability(fromKey, toKey) &&
- equalsIgnoreCaseAndNullability(fromValue, toValue)
+ equalsIgnoreCaseAndNullability(fromValue, toValue)
case (StructType(fromFields), StructType(toFields)) =>
fromFields.length == toFields.length &&
- fromFields.zip(toFields).forall { case (l, r) =>
- l.name.equalsIgnoreCase(r.name) &&
- equalsIgnoreCaseAndNullability(l.dataType, r.dataType)
- }
+ fromFields.zip(toFields).forall { case (l, r) =>
+ l.name.equalsIgnoreCase(r.name) &&
+ equalsIgnoreCaseAndNullability(l.dataType, r.dataType)
+ }
case (fromDataType, toDataType) => fromDataType == toDataType
}
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/DateType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DateType.scala
index d37ebbcdad727..402c4c0d95298 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/DateType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/DateType.scala
@@ -20,14 +20,15 @@ package org.apache.spark.sql.types
import org.apache.spark.annotation.Stable
/**
- * The date type represents a valid date in the proleptic Gregorian calendar.
- * Valid range is [0001-01-01, 9999-12-31].
+ * The date type represents a valid date in the proleptic Gregorian calendar. Valid range is
+ * [0001-01-01, 9999-12-31].
*
* Please use the singleton `DataTypes.DateType` to refer the type.
* @since 1.3.0
*/
@Stable
-class DateType private() extends DatetimeType {
+class DateType private () extends DatetimeType {
+
/**
* The default size of a value of the DateType is 4 bytes.
*/
@@ -37,10 +38,10 @@ class DateType private() extends DatetimeType {
}
/**
- * The companion case object and the DateType class is separated so the companion object
- * also subclasses the class. Otherwise, the companion object would be of type "DateType$"
- * in byte code. The DateType class is defined with a private constructor so its companion
- * object is the only possible instantiation.
+ * The companion case object and the DateType class are separated so the companion object also
+ * subclasses the class. Otherwise, the companion object would be of type "DateType$" in byte
+ * code. The DateType class is defined with a private constructor so its companion object is the
+ * only possible instantiation.
*
* @since 1.3.0
*/
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/DayTimeIntervalType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DayTimeIntervalType.scala
index a1d014fa51f1c..90d6d7c29a6ba 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/DayTimeIntervalType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/DayTimeIntervalType.scala
@@ -22,8 +22,8 @@ import org.apache.spark.sql.errors.DataTypeErrors
import org.apache.spark.sql.types.DayTimeIntervalType.fieldToString
/**
- * The type represents day-time intervals of the SQL standard. A day-time interval is made up
- * of a contiguous subset of the following fields:
+ * The type represents day-time intervals of the SQL standard. A day-time interval is made up of a
+ * contiguous subset of the following fields:
* - SECOND, seconds within minutes and possibly fractions of a second [0..59.999999],
* - MINUTE, minutes within hours [0..59],
* - HOUR, hours within days [0..23],
@@ -31,18 +31,21 @@ import org.apache.spark.sql.types.DayTimeIntervalType.fieldToString
*
* `DayTimeIntervalType` represents positive as well as negative day-time intervals.
*
- * @param startField The leftmost field which the type comprises of. Valid values:
- * 0 (DAY), 1 (HOUR), 2 (MINUTE), 3 (SECOND).
- * @param endField The rightmost field which the type comprises of. Valid values:
- * 0 (DAY), 1 (HOUR), 2 (MINUTE), 3 (SECOND).
+ * @param startField
+ * The leftmost field which the type comprises. Valid values: 0 (DAY), 1 (HOUR), 2 (MINUTE),
+ * 3 (SECOND).
+ * @param endField
+ * The rightmost field which the type comprises. Valid values: 0 (DAY), 1 (HOUR), 2 (MINUTE),
+ * 3 (SECOND).
*
* @since 3.2.0
*/
@Unstable
case class DayTimeIntervalType(startField: Byte, endField: Byte) extends AnsiIntervalType {
+
/**
- * The day-time interval type has constant precision. A value of the type always occupies 8 bytes.
- * The DAY field is constrained by the upper bound 106751991 to fit to `Long`.
+ * The day-time interval type has constant precision. A value of the type always occupies 8
+ * bytes. The DAY field is constrained by the upper bound 106751991 to fit to `Long`.
*/
override def defaultSize: Int = 8
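For reference, the start and end fields above are usually supplied via the companion constants referenced elsewhere in this patch; a sketch:

  import org.apache.spark.sql.types.DayTimeIntervalType
  import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR, SECOND}

  // Corresponds to SQL's INTERVAL HOUR TO SECOND.
  val hourToSecond = DayTimeIntervalType(HOUR, SECOND)
  // Single-field form, e.g. INTERVAL DAY.
  val dayOnly = DayTimeIntervalType(DAY)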
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 6de8570b1422f..bd94c386ab533 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -31,9 +31,9 @@ import org.apache.spark.unsafe.types.UTF8String
* A mutable implementation of BigDecimal that can hold a Long if values are small enough.
*
* The semantics of the fields are as follows:
- * - _precision and _scale represent the SQL precision and scale we are looking for
- * - If decimalVal is set, it represents the whole decimal value
- * - Otherwise, the decimal value is longVal / (10 ** _scale)
+ * - _precision and _scale represent the SQL precision and scale we are looking for
+ * - If decimalVal is set, it represents the whole decimal value
+ * - Otherwise, the decimal value is longVal / (10 ** _scale)
*
* Note, for values between -1.0 and 1.0, precision digits are only counted after dot.
*/
@@ -88,22 +88,22 @@ final class Decimal extends Ordered[Decimal] with Serializable {
}
/**
- * Set this Decimal to the given unscaled Long, with a given precision and scale,
- * and return it, or return null if it cannot be set due to overflow.
+ * Set this Decimal to the given unscaled Long, with a given precision and scale, and return it,
+ * or return null if it cannot be set due to overflow.
*/
def setOrNull(unscaled: Long, precision: Int, scale: Int): Decimal = {
DecimalType.checkNegativeScale(scale)
if (unscaled <= -POW_10(MAX_LONG_DIGITS) || unscaled >= POW_10(MAX_LONG_DIGITS)) {
// We can't represent this compactly as a long without risking overflow
if (precision < 19) {
- return null // Requested precision is too low to represent this value
+ return null // Requested precision is too low to represent this value
}
this.decimalVal = BigDecimal(unscaled, scale)
this.longVal = 0L
} else {
val p = POW_10(math.min(precision, MAX_LONG_DIGITS))
if (unscaled <= -p || unscaled >= p) {
- return null // Requested precision is too low to represent this value
+ return null // Requested precision is too low to represent this value
}
this.decimalVal = null
this.longVal = unscaled
@@ -126,7 +126,8 @@ final class Decimal extends Ordered[Decimal] with Serializable {
"roundedValue" -> decimalVal.toString,
"originalValue" -> decimal.toString,
"precision" -> precision.toString,
- "scale" -> scale.toString), Array.empty)
+ "scale" -> scale.toString),
+ Array.empty)
}
this.longVal = 0L
this._precision = precision
@@ -160,8 +161,8 @@ final class Decimal extends Ordered[Decimal] with Serializable {
}
/**
- * If the value is not in the range of long, convert it to BigDecimal and
- * the precision and scale are based on the converted value.
+ * If the value is not in the range of long, convert it to BigDecimal and the precision and
+ * scale are based on the converted value.
*
* This code avoids BigDecimal object allocation as much as possible to improve runtime efficiency
*/
@@ -262,37 +263,47 @@ final class Decimal extends Ordered[Decimal] with Serializable {
def toByte: Byte = toLong.toByte
/**
- * @return the Byte value that is equal to the rounded decimal.
- * @throws ArithmeticException if the decimal is too big to fit in Byte type.
+ * @return
+ * the Byte value that is equal to the rounded decimal.
+ * @throws ArithmeticException
+ * if the decimal is too big to fit in Byte type.
*/
private[sql] def roundToByte(): Byte =
- roundToNumeric[Byte](ByteType, Byte.MaxValue, Byte.MinValue) (_.toByte) (_.toByte)
+ roundToNumeric[Byte](ByteType, Byte.MaxValue, Byte.MinValue)(_.toByte)(_.toByte)
/**
- * @return the Short value that is equal to the rounded decimal.
- * @throws ArithmeticException if the decimal is too big to fit in Short type.
+ * @return
+ * the Short value that is equal to the rounded decimal.
+ * @throws ArithmeticException
+ * if the decimal is too big to fit in Short type.
*/
private[sql] def roundToShort(): Short =
- roundToNumeric[Short](ShortType, Short.MaxValue, Short.MinValue) (_.toShort) (_.toShort)
+ roundToNumeric[Short](ShortType, Short.MaxValue, Short.MinValue)(_.toShort)(_.toShort)
/**
- * @return the Int value that is equal to the rounded decimal.
- * @throws ArithmeticException if the decimal too big to fit in Int type.
+ * @return
+ * the Int value that is equal to the rounded decimal.
+ * @throws ArithmeticException
+ * if the decimal too big to fit in Int type.
*/
private[sql] def roundToInt(): Int =
- roundToNumeric[Int](IntegerType, Int.MaxValue, Int.MinValue) (_.toInt) (_.toInt)
+ roundToNumeric[Int](IntegerType, Int.MaxValue, Int.MinValue)(_.toInt)(_.toInt)
private def toSqlValue: String = this.toString + "BD"
- private def roundToNumeric[T <: AnyVal](integralType: IntegralType, maxValue: Int, minValue: Int)
- (f1: Long => T) (f2: Double => T): T = {
+ private def roundToNumeric[T <: AnyVal](
+ integralType: IntegralType,
+ maxValue: Int,
+ minValue: Int)(f1: Long => T)(f2: Double => T): T = {
if (decimalVal.eq(null)) {
val numericVal = f1(actualLongVal)
if (actualLongVal == numericVal) {
numericVal
} else {
throw DataTypeErrors.castingCauseOverflowError(
- toSqlValue, DecimalType(this.precision, this.scale), integralType)
+ toSqlValue,
+ DecimalType(this.precision, this.scale),
+ integralType)
}
} else {
val doubleVal = decimalVal.toDouble
@@ -300,14 +311,18 @@ final class Decimal extends Ordered[Decimal] with Serializable {
f2(doubleVal)
} else {
throw DataTypeErrors.castingCauseOverflowError(
- toSqlValue, DecimalType(this.precision, this.scale), integralType)
+ toSqlValue,
+ DecimalType(this.precision, this.scale),
+ integralType)
}
}
}
/**
- * @return the Long value that is equal to the rounded decimal.
- * @throws ArithmeticException if the decimal too big to fit in Long type.
+ * @return
+ * the Long value that is equal to the rounded decimal.
+ * @throws ArithmeticException
+ * if the decimal too big to fit in Long type.
*/
private[sql] def roundToLong(): Long = {
if (decimalVal.eq(null)) {
@@ -321,7 +336,9 @@ final class Decimal extends Ordered[Decimal] with Serializable {
} catch {
case _: ArithmeticException =>
throw DataTypeErrors.castingCauseOverflowError(
- toSqlValue, DecimalType(this.precision, this.scale), LongType)
+ toSqlValue,
+ DecimalType(this.precision, this.scale),
+ LongType)
}
}
}
@@ -329,7 +346,8 @@ final class Decimal extends Ordered[Decimal] with Serializable {
/**
* Update precision and scale while keeping our value the same, and return true if successful.
*
- * @return true if successful, false if overflow would occur
+ * @return
+ * true if successful, false if overflow would occur
*/
def changePrecision(precision: Int, scale: Int): Boolean = {
changePrecision(precision, scale, ROUND_HALF_UP)
@@ -338,8 +356,9 @@ final class Decimal extends Ordered[Decimal] with Serializable {
/**
* Create new `Decimal` with given precision and scale.
*
- * @return a non-null `Decimal` value if successful. Otherwise, if `nullOnOverflow` is true, null
- * is returned; if `nullOnOverflow` is false, an `ArithmeticException` is thrown.
+ * @return
+ * a non-null `Decimal` value if successful. Otherwise, if `nullOnOverflow` is true, null is
+ * returned; if `nullOnOverflow` is false, an `ArithmeticException` is thrown.
*/
private[sql] def toPrecision(
precision: Int,
@@ -354,8 +373,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
if (nullOnOverflow) {
null
} else {
- throw DataTypeErrors.cannotChangeDecimalPrecisionError(
- this, precision, scale, context)
+ throw DataTypeErrors.cannotChangeDecimalPrecisionError(this, precision, scale, context)
}
}
}
@@ -363,7 +381,8 @@ final class Decimal extends Ordered[Decimal] with Serializable {
/**
* Update precision and scale while keeping our value the same, and return true if successful.
*
- * @return true if successful, false if overflow would occur
+ * @return
+ * true if successful, false if overflow would occur
*/
private[sql] def changePrecision(
precision: Int,
@@ -482,7 +501,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
// ------------------------------------------------------------------------
// e1 + e2 max(s1, s2) + max(p1-s1, p2-s2) + 1 max(s1, s2)
// e1 - e2 max(s1, s2) + max(p1-s1, p2-s2) + 1 max(s1, s2)
- def + (that: Decimal): Decimal = {
+ def +(that: Decimal): Decimal = {
if (decimalVal.eq(null) && that.decimalVal.eq(null) && scale == that.scale) {
Decimal(longVal + that.longVal, Math.max(precision, that.precision) + 1, scale)
} else {
@@ -490,7 +509,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
}
}
- def - (that: Decimal): Decimal = {
+ def -(that: Decimal): Decimal = {
if (decimalVal.eq(null) && that.decimalVal.eq(null) && scale == that.scale) {
Decimal(longVal - that.longVal, Math.max(precision, that.precision) + 1, scale)
} else {
@@ -499,14 +518,19 @@ final class Decimal extends Ordered[Decimal] with Serializable {
}
// TypeCoercion will take care of the precision, scale of result
- def * (that: Decimal): Decimal =
+ def *(that: Decimal): Decimal =
Decimal(toJavaBigDecimal.multiply(that.toJavaBigDecimal, MATH_CONTEXT))
- def / (that: Decimal): Decimal =
- if (that.isZero) null else Decimal(toJavaBigDecimal.divide(that.toJavaBigDecimal,
- DecimalType.MAX_SCALE + 1, MATH_CONTEXT.getRoundingMode))
+ def /(that: Decimal): Decimal =
+ if (that.isZero) {
+ null
+ } else {
+ Decimal(
+ toJavaBigDecimal
+ .divide(that.toJavaBigDecimal, DecimalType.MAX_SCALE + 1, MATH_CONTEXT.getRoundingMode))
+ }
- def % (that: Decimal): Decimal =
+ def %(that: Decimal): Decimal =
if (that.isZero) null
else Decimal(toJavaBigDecimal.remainder(that.toJavaBigDecimal, MATH_CONTEXT))
@@ -526,12 +550,14 @@ final class Decimal extends Ordered[Decimal] with Serializable {
def abs: Decimal = if (this < Decimal.ZERO) this.unary_- else this
- def floor: Decimal = if (scale == 0) this else {
+ def floor: Decimal = if (scale == 0) this
+ else {
val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision
toPrecision(newPrecision, 0, ROUND_FLOOR, nullOnOverflow = false)
}
- def ceil: Decimal = if (scale == 0) this else {
+ def ceil: Decimal = if (scale == 0) this
+ else {
val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision
toPrecision(newPrecision, 0, ROUND_CEILING, nullOnOverflow = false)
}
@@ -612,7 +638,7 @@ object Decimal {
// We fast fail because constructing a very large JavaBigDecimal to Decimal is very slow.
// For example: Decimal("6.0790316E+25569151")
if (numDigitsInIntegralPart(bigDecimal) > DecimalType.MAX_PRECISION &&
- !SqlApiConf.get.allowNegativeScaleOfDecimalEnabled) {
+ !SqlApiConf.get.allowNegativeScaleOfDecimalEnabled) {
null
} else {
Decimal(bigDecimal)
@@ -632,7 +658,7 @@ object Decimal {
// We fast fail because constructing a very large JavaBigDecimal to Decimal is very slow.
// For example: Decimal("6.0790316E+25569151")
if (numDigitsInIntegralPart(bigDecimal) > DecimalType.MAX_PRECISION &&
- !SqlApiConf.get.allowNegativeScaleOfDecimalEnabled) {
+ !SqlApiConf.get.allowNegativeScaleOfDecimalEnabled) {
throw DataTypeErrors.outOfDecimalTypeRangeError(str)
} else {
Decimal(bigDecimal)
@@ -657,16 +683,18 @@ object Decimal {
// Max precision of a decimal value stored in `numBytes` bytes
def maxPrecisionForBytes(numBytes: Int): Int = {
- Math.round( // convert double to long
- Math.floor(Math.log10( // number of base-10 digits
- Math.pow(2, 8 * numBytes - 1) - 1))) // max value stored in numBytes
+ Math
+ .round( // convert double to long
+ Math.floor(Math.log10( // number of base-10 digits
+ Math.pow(2, 8 * numBytes - 1) - 1))
+ ) // max value stored in numBytes
.asInstanceOf[Int]
}
// Returns the minimum number of bytes needed to store a decimal with a given `precision`.
lazy val minBytesForPrecision = Array.tabulate[Int](39)(computeMinBytesForPrecision)
- private def computeMinBytesForPrecision(precision : Int) : Int = {
+ private def computeMinBytesForPrecision(precision: Int): Int = {
var numBytes = 1
while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) {
numBytes += 1
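As a worked example of the precision/scale behaviour and the `changePrecision` contract above (a sketch; the `Decimal("...")` string factory is assumed from the companion object):

  import org.apache.spark.sql.types.Decimal

  val d = Decimal("123.45")        // precision 5, scale 2, held compactly as a Long
  d.changePrecision(7, 3)          // widening succeeds, returns true (value is now 123.450)
  d.changePrecision(4, 2)          // 123.45 needs 5 digits, would overflow: returns false

  // 8 bytes hold at most 18 decimal digits: floor(log10(2^63 - 1)) = 18.
  Decimal.maxPrecisionForBytes(8)  // 18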
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index 9de34d0b3bc16..bff483cefda91 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -26,9 +26,8 @@ import org.apache.spark.sql.errors.DataTypeErrors
import org.apache.spark.sql.internal.SqlApiConf
/**
- * The data type representing `java.math.BigDecimal` values.
- * A Decimal that must have fixed precision (the maximum number of digits) and scale (the number
- * of digits on right side of dot).
+ * The data type representing `java.math.BigDecimal` values. A Decimal that must have fixed
+ * precision (the maximum number of digits) and scale (the number of digits on right side of dot).
*
* The precision can be up to 38, scale can also be up to 38 (less or equal to precision).
*
@@ -49,7 +48,8 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType {
if (precision > DecimalType.MAX_PRECISION) {
throw DataTypeErrors.decimalPrecisionExceedsMaxPrecisionError(
- precision, DecimalType.MAX_PRECISION)
+ precision,
+ DecimalType.MAX_PRECISION)
}
// default constructor for Java
@@ -63,8 +63,8 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType {
override def sql: String = typeName.toUpperCase(Locale.ROOT)
/**
- * Returns whether this DecimalType is wider than `other`. If yes, it means `other`
- * can be casted into `this` safely without losing any precision or range.
+ * Returns whether this DecimalType is wider than `other`. If yes, it means `other` can be
+ * casted into `this` safely without losing any precision or range.
*/
private[sql] def isWiderThan(other: DataType): Boolean = isWiderThanInternal(other)
@@ -78,8 +78,8 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType {
}
/**
- * Returns whether this DecimalType is tighter than `other`. If yes, it means `this`
- * can be casted into `other` safely without losing any precision or range.
+ * Returns whether this DecimalType is tighter than `other`. If yes, it means `this` can be
+ * casted into `other` safely without losing any precision or range.
*/
private[sql] def isTighterThan(other: DataType): Boolean = other match {
case dt: DecimalType =>
@@ -94,8 +94,8 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType {
}
/**
- * The default size of a value of the DecimalType is 8 bytes when precision is at most 18,
- * and 16 bytes otherwise.
+ * The default size of a value of the DecimalType is 8 bytes when precision is at most 18, and
+ * 16 bytes otherwise.
*/
override def defaultSize: Int = if (precision <= Decimal.MAX_LONG_DIGITS) 8 else 16
@@ -104,7 +104,6 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType {
private[spark] override def asNullable: DecimalType = this
}
-
/**
* Extra factory methods and pattern matchers for Decimals.
*
@@ -167,10 +166,11 @@ object DecimalType extends AbstractDataType {
/**
* Scale adjustment implementation is based on Hive's one, which is itself inspired to
* SQLServer's one. In particular, when a result precision is greater than
- * {@link #MAX_PRECISION}, the corresponding scale is reduced to prevent the integral part of a
+ * {@link #MAX_PRECISION} , the corresponding scale is reduced to prevent the integral part of a
* result from being truncated.
*
- * This method is used only when `spark.sql.decimalOperations.allowPrecisionLoss` is set to true.
+ * This method is used only when `spark.sql.decimalOperations.allowPrecisionLoss` is set to
+ * true.
*/
private[sql] def adjustPrecisionScale(precision: Int, scale: Int): DecimalType = {
// Assumptions:
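
For readers skimming the DecimalType doc changes above: a minimal sketch of how fixed precision/scale and `defaultSize` behave, using only the public constructor and the size rule stated in the comments (the literal values below are illustrative, not taken from this patch).

```scala
import org.apache.spark.sql.types.DecimalType

// precision = total number of digits, scale = digits to the right of the dot (both <= 38).
val price = DecimalType(10, 2)   // values such as 12345678.90
val wide  = DecimalType(38, 18)

// Per the doc above: 8 bytes while precision is at most 18, 16 bytes otherwise.
assert(price.defaultSize == 8)
assert(wide.defaultSize == 16)
```
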
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/DoubleType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DoubleType.scala
index bc0ed725cf266..873f0c237c6c4 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/DoubleType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/DoubleType.scala
@@ -27,7 +27,8 @@ import org.apache.spark.annotation.Stable
* @since 1.3.0
*/
@Stable
-class DoubleType private() extends FractionalType {
+class DoubleType private () extends FractionalType {
+
/**
* The default size of a value of the DoubleType is 8 bytes.
*/
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/FloatType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/FloatType.scala
index 8b54f830d48a3..df4b03cd42bd4 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/FloatType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/FloatType.scala
@@ -27,7 +27,8 @@ import org.apache.spark.annotation.Stable
* @since 1.3.0
*/
@Stable
-class FloatType private() extends FractionalType {
+class FloatType private () extends FractionalType {
+
/**
* The default size of a value of the FloatType is 4 bytes.
*/
@@ -36,7 +37,6 @@ class FloatType private() extends FractionalType {
private[spark] override def asNullable: FloatType = this
}
-
/**
* @since 1.3.0
*/
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/IntegerType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/IntegerType.scala
index b26a555c9b572..dc4727cb1215b 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/IntegerType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/IntegerType.scala
@@ -25,7 +25,8 @@ import org.apache.spark.annotation.Stable
* @since 1.3.0
*/
@Stable
-class IntegerType private() extends IntegralType {
+class IntegerType private () extends IntegralType {
+
/**
* The default size of a value of the IntegerType is 4 bytes.
*/
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/LongType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/LongType.scala
index 87ebacfe9ce88..f65c4c70acd27 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/LongType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/LongType.scala
@@ -25,7 +25,8 @@ import org.apache.spark.annotation.Stable
* @since 1.3.0
*/
@Stable
-class LongType private() extends IntegralType {
+class LongType private () extends IntegralType {
+
/**
* The default size of a value of the LongType is 8 bytes.
*/
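
The `private ()` constructor tweaks above are formatting-only; these numeric types remain singletons. A small usage sketch, referencing the companion objects directly as the surrounding docs recommend:

```scala
import org.apache.spark.sql.types._

// The numeric types touched above are singletons; reference them directly
// (or via DataTypes.* from Java) when building a schema.
val schema = StructType(Seq(
  StructField("id", LongType, nullable = false),
  StructField("age", IntegerType),
  StructField("score", DoubleType)))
```
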
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/MapType.scala
index dba870466fc1c..1dfb9aaf9e29b 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/MapType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/MapType.scala
@@ -28,15 +28,16 @@ import org.apache.spark.sql.catalyst.util.StringConcat
*
* Please use `DataTypes.createMapType()` to create a specific instance.
*
- * @param keyType The data type of map keys.
- * @param valueType The data type of map values.
- * @param valueContainsNull Indicates if map values have `null` values.
+ * @param keyType
+ * The data type of map keys.
+ * @param valueType
+ * The data type of map values.
+ * @param valueContainsNull
+ * Indicates if map values have `null` values.
*/
@Stable
-case class MapType(
- keyType: DataType,
- valueType: DataType,
- valueContainsNull: Boolean) extends DataType {
+case class MapType(keyType: DataType, valueType: DataType, valueContainsNull: Boolean)
+ extends DataType {
/** No-arg constructor for kryo. */
def this() = this(null, null, false)
@@ -48,8 +49,9 @@ case class MapType(
if (maxDepth > 0) {
stringConcat.append(s"$prefix-- key: ${keyType.typeName}\n")
DataType.buildFormattedString(keyType, s"$prefix |", stringConcat, maxDepth)
- stringConcat.append(s"$prefix-- value: ${valueType.typeName} " +
- s"(valueContainsNull = $valueContainsNull)\n")
+ stringConcat.append(
+ s"$prefix-- value: ${valueType.typeName} " +
+ s"(valueContainsNull = $valueContainsNull)\n")
DataType.buildFormattedString(valueType, s"$prefix |", stringConcat, maxDepth)
}
}
@@ -61,9 +63,9 @@ case class MapType(
("valueContainsNull" -> valueContainsNull)
/**
- * The default size of a value of the MapType is
- * (the default size of the key type + the default size of the value type).
- * We assume that there is only 1 element on average in a map. See SPARK-18853.
+ * The default size of a value of the MapType is (the default size of the key type + the default
+ * size of the value type). We assume that there is only 1 element on average in a map. See
+ * SPARK-18853.
*/
override def defaultSize: Int = 1 * (keyType.defaultSize + valueType.defaultSize)
@@ -77,8 +79,8 @@ case class MapType(
MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true)
/**
- * Returns the same data type but set all nullability fields are true
- * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`).
+ * Returns the same data type but set all nullability fields are true (`StructField.nullable`,
+ * `ArrayType.containsNull`, and `MapType.valueContainsNull`).
*
* @since 4.0.0
*/
@@ -104,8 +106,8 @@ object MapType extends AbstractDataType {
override private[sql] def simpleString: String = "map"
/**
- * Construct a [[MapType]] object with the given key type and value type.
- * The `valueContainsNull` is true.
+ * Construct a [[MapType]] object with the given key type and value type. The
+ * `valueContainsNull` is true.
*/
def apply(keyType: DataType, valueType: DataType): MapType =
MapType(keyType: DataType, valueType: DataType, valueContainsNull = true)
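
A short sketch of the two `MapType` factories and the `defaultSize` rule described in the doc changes above (key/value types chosen for illustration):

```scala
import org.apache.spark.sql.types._

// The two-argument apply defaults valueContainsNull to true, as documented above.
val m1 = MapType(StringType, IntegerType)
assert(m1.valueContainsNull)

val m2 = MapType(StringType, LongType, valueContainsNull = false)

// defaultSize assumes one entry per map on average: key size + value size (SPARK-18853).
assert(m2.defaultSize == StringType.defaultSize + LongType.defaultSize)
```
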
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/Metadata.scala
index 70e03905d4b05..05f91b5ba2313 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/Metadata.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/Metadata.scala
@@ -26,7 +26,6 @@ import org.apache.spark.annotation.Stable
import org.apache.spark.sql.errors.DataTypeErrors
import org.apache.spark.util.ArrayImplicits._
-
/**
* Metadata is a wrapper over Map[String, Any] that limits the value type to simple ones: Boolean,
* Long, Double, String, Metadata, Array[Boolean], Array[Long], Array[Double], Array[String], and
@@ -35,13 +34,14 @@ import org.apache.spark.util.ArrayImplicits._
* The default constructor is private. User should use either [[MetadataBuilder]] or
* `Metadata.fromJson()` to create Metadata instances.
*
- * @param map an immutable map that stores the data
+ * @param map
+ * an immutable map that stores the data
*
* @since 1.3.0
*/
@Stable
sealed class Metadata private[types] (private[types] val map: Map[String, Any])
- extends Serializable {
+ extends Serializable {
/** No-arg constructor for kryo. */
protected def this() = this(null)
@@ -173,7 +173,8 @@ object Metadata {
builder.putStringArray(key, value.asInstanceOf[List[JString]].map(_.s).toArray)
case _: JObject =>
builder.putMetadataArray(
- key, value.asInstanceOf[List[JObject]].map(fromJObject).toArray)
+ key,
+ value.asInstanceOf[List[JObject]].map(fromJObject).toArray)
case other =>
throw DataTypeErrors.unsupportedArrayTypeError(other.getClass)
}
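
As the doc above notes, `Metadata` instances are built through `MetadataBuilder` or `Metadata.fromJson()`; a minimal sketch (the keys are illustrative):

```scala
import org.apache.spark.sql.types.{Metadata, MetadataBuilder}

val md: Metadata = new MetadataBuilder()
  .putString("comment", "user id")
  .putLong("maxLength", 64L)
  .build()

// Round-trips through the JSON form that fromJson() parses.
val copy = Metadata.fromJson(md.json)
assert(copy == md)
```
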
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/NullType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/NullType.scala
index d211fac70c641..4e7fd3a00a8af 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/NullType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/NullType.scala
@@ -25,7 +25,7 @@ import org.apache.spark.annotation.Stable
* @since 1.3.0
*/
@Stable
-class NullType private() extends DataType {
+class NullType private () extends DataType {
// The companion object and this class is separated so the companion object also subclasses
// this type. Otherwise, the companion object would be of type "NullType$" in byte code.
// Defined with a private constructor so the companion object is the only possible instantiation.
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/ShortType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/ShortType.scala
index 66696793e6279..c3b6bc75facd3 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/ShortType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/ShortType.scala
@@ -25,7 +25,8 @@ import org.apache.spark.annotation.Stable
* @since 1.3.0
*/
@Stable
-class ShortType private() extends IntegralType {
+class ShortType private () extends IntegralType {
+
/**
* The default size of a value of the ShortType is 2 bytes.
*/
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala
index df77f091f41f4..c2dd6cec7ba74 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala
@@ -26,15 +26,17 @@ import org.apache.spark.sql.catalyst.util.CollationFactory
* The data type representing `String` values. Please use the singleton `DataTypes.StringType`.
*
* @since 1.3.0
- * @param collationId The id of collation for this StringType.
+ * @param collationId
+ * The id of collation for this StringType.
*/
@Stable
-class StringType private(val collationId: Int) extends AtomicType with Serializable {
+class StringType private (val collationId: Int) extends AtomicType with Serializable {
+
/**
- * Support for Binary Equality implies that strings are considered equal only if
- * they are byte for byte equal. E.g. all accent or case-insensitive collations are considered
- * non-binary. If this field is true, byte level operations can be used against this datatype
- * (e.g. for equality and hashing).
+ * Support for Binary Equality implies that strings are considered equal only if they are byte
+ * for byte equal. E.g. all accent or case-insensitive collations are considered non-binary. If
+ * this field is true, byte level operations can be used against this datatype (e.g. for
+ * equality and hashing).
*/
private[sql] def supportsBinaryEquality: Boolean =
CollationFactory.fetchCollation(collationId).supportsBinaryEquality
@@ -42,6 +44,9 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa
private[sql] def supportsLowercaseEquality: Boolean =
CollationFactory.fetchCollation(collationId).supportsLowercaseEquality
+ private[sql] def isNonCSAI: Boolean =
+ !CollationFactory.isCaseSensitiveAndAccentInsensitive(collationId)
+
private[sql] def isUTF8BinaryCollation: Boolean =
collationId == CollationFactory.UTF8_BINARY_COLLATION_ID
@@ -49,18 +54,18 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa
collationId == CollationFactory.UTF8_LCASE_COLLATION_ID
/**
- * Support for Binary Ordering implies that strings are considered equal only
- * if they are byte for byte equal. E.g. all accent or case-insensitive collations are
- * considered non-binary. Also their ordering does not require calls to ICU library, as
- * it follows spark internal implementation. If this field is true, byte level operations
- * can be used against this datatype (e.g. for equality, hashing and ordering).
+ * Support for Binary Ordering implies that strings are considered equal only if they are byte
+ * for byte equal. E.g. all accent or case-insensitive collations are considered non-binary.
+ * Also their ordering does not require calls to ICU library, as it follows spark internal
+ * implementation. If this field is true, byte level operations can be used against this
+ * datatype (e.g. for equality, hashing and ordering).
*/
private[sql] def supportsBinaryOrdering: Boolean =
CollationFactory.fetchCollation(collationId).supportsBinaryOrdering
/**
- * Type name that is shown to the customer.
- * If this is an UTF8_BINARY collation output is `string` due to backwards compatibility.
+ * Type name that is shown to the customer. If this is a UTF8_BINARY collation, the output is
+ * `string` due to backwards compatibility.
*/
override def typeName: String =
if (isUTF8BinaryCollation) "string"
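
For the collation-related doc changes above: the default `StringType` singleton uses the UTF8_BINARY collation, so its reported type name stays `string` for backwards compatibility. A small check against the public singleton:

```scala
import org.apache.spark.sql.types.StringType

// The default singleton is the UTF8_BINARY collation, so typeName is plain "string".
assert(StringType.typeName == "string")
```
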
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala
index 3ff96fea9ee04..d4e590629921c 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala
@@ -31,11 +31,15 @@ import org.apache.spark.util.SparkSchemaUtils
/**
* A field inside a StructType.
- * @param name The name of this field.
- * @param dataType The data type of this field.
- * @param nullable Indicates if values of this field can be `null` values.
- * @param metadata The metadata of this field. The metadata should be preserved during
- * transformation if the content of the column is not modified, e.g, in selection.
+ * @param name
+ * The name of this field.
+ * @param dataType
+ * The data type of this field.
+ * @param nullable
+ * Indicates if values of this field can be `null` values.
+ * @param metadata
+ * The metadata of this field. The metadata should be preserved during transformation if the
+ * content of the column is not modified, e.g, in selection.
*
* @since 1.3.0
*/
@@ -54,8 +58,9 @@ case class StructField(
stringConcat: StringConcat,
maxDepth: Int): Unit = {
if (maxDepth > 0) {
- stringConcat.append(s"$prefix-- ${SparkSchemaUtils.escapeMetaCharacters(name)}: " +
- s"${dataType.typeName} (nullable = $nullable)\n")
+ stringConcat.append(
+ s"$prefix-- ${SparkSchemaUtils.escapeMetaCharacters(name)}: " +
+ s"${dataType.typeName} (nullable = $nullable)\n")
DataType.buildFormattedString(dataType, s"$prefix |", stringConcat, maxDepth)
}
}
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala
index 3e637b5110122..4ef1cf400b80e 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -41,10 +41,10 @@ import org.apache.spark.util.SparkCollectionUtils
* {{{
* StructType(fields: Seq[StructField])
* }}}
- * For a [[StructType]] object, one or multiple [[StructField]]s can be extracted by names.
- * If multiple [[StructField]]s are extracted, a [[StructType]] object will be returned.
- * If a provided name does not have a matching field, it will be ignored. For the case
- * of extracting a single [[StructField]], a `null` will be returned.
+ * For a [[StructType]] object, one or multiple [[StructField]]s can be extracted by names. If
+ * multiple [[StructField]]s are extracted, a [[StructType]] object will be returned. If a
+ * provided name does not have a matching field, it will be ignored. For the case of extracting a
+ * single [[StructField]], a `null` will be returned.
*
* Scala Example:
* {{{
@@ -126,8 +126,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
override def equals(that: Any): Boolean = {
that match {
case StructType(otherFields) =>
- java.util.Arrays.equals(
- fields.asInstanceOf[Array[AnyRef]], otherFields.asInstanceOf[Array[AnyRef]])
+ java.util.Arrays
+ .equals(fields.asInstanceOf[Array[AnyRef]], otherFields.asInstanceOf[Array[AnyRef]])
case _ => false
}
}
@@ -146,7 +146,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
* .add(StructField("a", IntegerType, true))
* .add(StructField("b", LongType, false))
* .add(StructField("c", StringType, true))
- *}}}
+ * }}}
*/
def add(field: StructField): StructType = {
StructType(fields :+ field)
@@ -155,10 +155,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
/**
* Creates a new [[StructType]] by adding a new nullable field with no metadata.
*
- * val struct = (new StructType)
- * .add("a", IntegerType)
- * .add("b", LongType)
- * .add("c", StringType)
+ * val struct = (new StructType) .add("a", IntegerType) .add("b", LongType) .add("c",
+ * StringType)
*/
def add(name: String, dataType: DataType): StructType = {
StructType(fields :+ StructField(name, dataType, nullable = true, Metadata.empty))
@@ -167,10 +165,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
/**
* Creates a new [[StructType]] by adding a new field with no metadata.
*
- * val struct = (new StructType)
- * .add("a", IntegerType, true)
- * .add("b", LongType, false)
- * .add("c", StringType, true)
+ * val struct = (new StructType) .add("a", IntegerType, true) .add("b", LongType, false)
+ * .add("c", StringType, true)
*/
def add(name: String, dataType: DataType, nullable: Boolean): StructType = {
StructType(fields :+ StructField(name, dataType, nullable, Metadata.empty))
@@ -185,11 +181,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
* .add("c", StringType, true, Metadata.empty)
* }}}
*/
- def add(
- name: String,
- dataType: DataType,
- nullable: Boolean,
- metadata: Metadata): StructType = {
+ def add(name: String, dataType: DataType, nullable: Boolean, metadata: Metadata): StructType = {
StructType(fields :+ StructField(name, dataType, nullable, metadata))
}
@@ -202,11 +194,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
* .add("c", StringType, true, "comment3")
* }}}
*/
- def add(
- name: String,
- dataType: DataType,
- nullable: Boolean,
- comment: String): StructType = {
+ def add(name: String, dataType: DataType, nullable: Boolean, comment: String): StructType = {
StructType(fields :+ StructField(name, dataType, nullable).withComment(comment))
}
@@ -226,8 +214,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
}
/**
- * Creates a new [[StructType]] by adding a new field with no metadata where the
- * dataType is specified as a String.
+ * Creates a new [[StructType]] by adding a new field with no metadata where the dataType is
+ * specified as a String.
*
* {{{
* val struct = (new StructType)
@@ -241,8 +229,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
}
/**
- * Creates a new [[StructType]] by adding a new field and specifying metadata where the
- * dataType is specified as a String.
+ * Creates a new [[StructType]] by adding a new field and specifying metadata where the dataType
+ * is specified as a String.
* {{{
* val struct = (new StructType)
* .add("a", "int", true, Metadata.empty)
@@ -250,17 +238,13 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
* .add("c", "string", true, Metadata.empty)
* }}}
*/
- def add(
- name: String,
- dataType: String,
- nullable: Boolean,
- metadata: Metadata): StructType = {
+ def add(name: String, dataType: String, nullable: Boolean, metadata: Metadata): StructType = {
add(name, DataTypeParser.parseDataType(dataType), nullable, metadata)
}
/**
- * Creates a new [[StructType]] by adding a new field and specifying metadata where the
- * dataType is specified as a String.
+ * Creates a new [[StructType]] by adding a new field and specifying metadata where the dataType
+ * is specified as a String.
* {{{
* val struct = (new StructType)
* .add("a", "int", true, "comment1")
@@ -268,21 +252,19 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
* .add("c", "string", true, "comment3")
* }}}
*/
- def add(
- name: String,
- dataType: String,
- nullable: Boolean,
- comment: String): StructType = {
+ def add(name: String, dataType: String, nullable: Boolean, comment: String): StructType = {
add(name, DataTypeParser.parseDataType(dataType), nullable, comment)
}
/**
* Extracts the [[StructField]] with the given name.
*
- * @throws IllegalArgumentException if a field with the given name does not exist
+ * @throws IllegalArgumentException
+ * if a field with the given name does not exist
*/
def apply(name: String): StructField = {
- nameToField.getOrElse(name,
+ nameToField.getOrElse(
+ name,
throw new SparkIllegalArgumentException(
errorClass = "FIELD_NOT_FOUND",
messageParameters = immutable.Map(
@@ -294,7 +276,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
* Returns a [[StructType]] containing [[StructField]]s of the given names, preserving the
* original order of fields.
*
- * @throws IllegalArgumentException if at least one given field name does not exist
+ * @throws IllegalArgumentException
+ * if at least one given field name does not exist
*/
def apply(names: Set[String]): StructType = {
val nonExistFields = names -- fieldNamesSet
@@ -312,10 +295,12 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
/**
* Returns the index of a given field.
*
- * @throws IllegalArgumentException if a field with the given name does not exist
+ * @throws IllegalArgumentException
+ * if a field with the given name does not exist
*/
def fieldIndex(name: String): Int = {
- nameToIndex.getOrElse(name,
+ nameToIndex.getOrElse(
+ name,
throw new SparkIllegalArgumentException(
errorClass = "FIELD_NOT_FOUND",
messageParameters = immutable.Map(
@@ -354,10 +339,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
} else if (found.isEmpty) {
None
} else {
- findField(
- parent = found.head,
- searchPath = searchPath.tail,
- normalizedPath)
+ findField(parent = found.head, searchPath = searchPath.tail, normalizedPath)
}
}
@@ -433,11 +415,14 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
override def defaultSize: Int = fields.map(_.dataType.defaultSize).sum
override def simpleString: String = {
- val fieldTypes = fields.to(LazyList)
+ val fieldTypes = fields
+ .to(LazyList)
.map(field => s"${field.name}:${field.dataType.simpleString}")
SparkStringUtils.truncatedString(
fieldTypes,
- "struct<", ",", ">",
+ "struct<",
+ ",",
+ ">",
SqlApiConf.get.maxToStringFields)
}
@@ -460,9 +445,9 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
/**
* Returns a string containing a schema in DDL format. For example, the following value:
- * `StructType(Seq(StructField("eventId", IntegerType), StructField("s", StringType)))`
- * will be converted to `eventId` INT, `s` STRING.
- * The returned DDL schema can be used in a table creation.
+ * `StructType(Seq(StructField("eventId", IntegerType), StructField("s", StringType)))` will be
+ * converted to `eventId` INT, `s` STRING. The returned DDL schema can be used in a table
+ * creation.
*
* @since 2.4.0
*/
@@ -470,8 +455,8 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
private[sql] override def simpleString(maxNumberFields: Int): String = {
val builder = new StringBuilder
- val fieldTypes = fields.take(maxNumberFields).map {
- f => s"${f.name}: ${f.dataType.simpleString(maxNumberFields)}"
+ val fieldTypes = fields.take(maxNumberFields).map { f =>
+ s"${f.name}: ${f.dataType.simpleString(maxNumberFields)}"
}
builder.append("struct<")
builder.append(fieldTypes.mkString(", "))
@@ -486,31 +471,29 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
}
/**
- * Merges with another schema (`StructType`). For a struct field A from `this` and a struct field
- * B from `that`,
+ * Merges with another schema (`StructType`). For a struct field A from `this` and a struct
+ * field B from `that`,
*
- * 1. If A and B have the same name and data type, they are merged to a field C with the same name
- * and data type. C is nullable if and only if either A or B is nullable.
- * 2. If A doesn't exist in `that`, it's included in the result schema.
- * 3. If B doesn't exist in `this`, it's also included in the result schema.
- * 4. Otherwise, `this` and `that` are considered as conflicting schemas and an exception would be
- * thrown.
+ * 1. If A and B have the same name and data type, they are merged to a field C with the same
+ * name and data type. C is nullable if and only if either A or B is nullable. 2. If A
+ * doesn't exist in `that`, it's included in the result schema. 3. If B doesn't exist in
+ * `this`, it's also included in the result schema. 4. Otherwise, `this` and `that` are
+ * considered as conflicting schemas and an exception would be thrown.
*/
private[sql] def merge(that: StructType, caseSensitive: Boolean = true): StructType =
StructType.merge(this, that, caseSensitive).asInstanceOf[StructType]
override private[spark] def asNullable: StructType = {
- val newFields = fields.map {
- case StructField(name, dataType, nullable, metadata) =>
- StructField(name, dataType.asNullable, nullable = true, metadata)
+ val newFields = fields.map { case StructField(name, dataType, nullable, metadata) =>
+ StructField(name, dataType.asNullable, nullable = true, metadata)
}
StructType(newFields)
}
/**
- * Returns the same data type but set all nullability fields are true
- * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`).
+ * Returns the same data type but set all nullability fields are true (`StructField.nullable`,
+ * `ArrayType.containsNull`, and `MapType.valueContainsNull`).
*
* @since 4.0.0
*/
@@ -562,7 +545,8 @@ object StructType extends AbstractDataType {
case StructType(fields) =>
val newFields = fields.map { f =>
val mb = new MetadataBuilder()
- f.copy(dataType = removeMetadata(key, f.dataType),
+ f.copy(
+ dataType = removeMetadata(key, f.dataType),
metadata = mb.withMetadata(f.metadata).remove(key).build())
}
StructType(newFields)
@@ -570,39 +554,49 @@ object StructType extends AbstractDataType {
}
/**
- * This leverages `merge` to merge data types for UNION operator by specializing
- * the handling of struct types to follow UNION semantics.
+ * This leverages `merge` to merge data types for UNION operator by specializing the handling of
+ * struct types to follow UNION semantics.
*/
private[sql] def unionLikeMerge(left: DataType, right: DataType): DataType =
- mergeInternal(left, right, (s1: StructType, s2: StructType) => {
- val leftFields = s1.fields
- val rightFields = s2.fields
- require(leftFields.length == rightFields.length, "To merge nullability, " +
- "two structs must have same number of fields.")
-
- val newFields = leftFields.zip(rightFields).map {
- case (leftField, rightField) =>
+ mergeInternal(
+ left,
+ right,
+ (s1: StructType, s2: StructType) => {
+ val leftFields = s1.fields
+ val rightFields = s2.fields
+ require(
+ leftFields.length == rightFields.length,
+ "To merge nullability, " +
+ "two structs must have same number of fields.")
+
+ val newFields = leftFields.zip(rightFields).map { case (leftField, rightField) =>
leftField.copy(
dataType = unionLikeMerge(leftField.dataType, rightField.dataType),
nullable = leftField.nullable || rightField.nullable)
- }
- StructType(newFields)
- })
-
- private[sql] def merge(left: DataType, right: DataType, caseSensitive: Boolean = true): DataType =
- mergeInternal(left, right, (s1: StructType, s2: StructType) => {
- val leftFields = s1.fields
- val rightFields = s2.fields
- val newFields = mutable.ArrayBuffer.empty[StructField]
+ }
+ StructType(newFields)
+ })
- def normalize(name: String): String = {
- if (caseSensitive) name else name.toLowerCase(Locale.ROOT)
- }
+ private[sql] def merge(
+ left: DataType,
+ right: DataType,
+ caseSensitive: Boolean = true): DataType =
+ mergeInternal(
+ left,
+ right,
+ (s1: StructType, s2: StructType) => {
+ val leftFields = s1.fields
+ val rightFields = s2.fields
+ val newFields = mutable.ArrayBuffer.empty[StructField]
+
+ def normalize(name: String): String = {
+ if (caseSensitive) name else name.toLowerCase(Locale.ROOT)
+ }
- val rightMapped = fieldsMap(rightFields, caseSensitive)
- leftFields.foreach {
- case leftField @ StructField(leftName, leftType, leftNullable, _) =>
- rightMapped.get(normalize(leftName))
+ val rightMapped = fieldsMap(rightFields, caseSensitive)
+ leftFields.foreach { case leftField @ StructField(leftName, leftType, leftNullable, _) =>
+ rightMapped
+ .get(normalize(leftName))
.map { case rightField @ StructField(rightName, rightType, rightNullable, _) =>
try {
leftField.copy(
@@ -610,39 +604,40 @@ object StructType extends AbstractDataType {
nullable = leftNullable || rightNullable)
} catch {
case NonFatal(e) =>
- throw DataTypeErrors.cannotMergeIncompatibleDataTypesError(
- leftType, rightType)
+ throw DataTypeErrors.cannotMergeIncompatibleDataTypesError(leftType, rightType)
}
}
.orElse {
Some(leftField)
}
.foreach(newFields += _)
- }
-
- val leftMapped = fieldsMap(leftFields, caseSensitive)
- rightFields
- .filterNot(f => leftMapped.contains(normalize(f.name)))
- .foreach { f =>
- newFields += f
}
- StructType(newFields.toArray)
- })
+ val leftMapped = fieldsMap(leftFields, caseSensitive)
+ rightFields
+ .filterNot(f => leftMapped.contains(normalize(f.name)))
+ .foreach { f =>
+ newFields += f
+ }
+
+ StructType(newFields.toArray)
+ })
private def mergeInternal(
left: DataType,
right: DataType,
mergeStruct: (StructType, StructType) => StructType): DataType =
(left, right) match {
- case (ArrayType(leftElementType, leftContainsNull),
- ArrayType(rightElementType, rightContainsNull)) =>
+ case (
+ ArrayType(leftElementType, leftContainsNull),
+ ArrayType(rightElementType, rightContainsNull)) =>
ArrayType(
mergeInternal(leftElementType, rightElementType, mergeStruct),
leftContainsNull || rightContainsNull)
- case (MapType(leftKeyType, leftValueType, leftContainsNull),
- MapType(rightKeyType, rightValueType, rightContainsNull)) =>
+ case (
+ MapType(leftKeyType, leftValueType, leftContainsNull),
+ MapType(rightKeyType, rightValueType, rightContainsNull)) =>
MapType(
mergeInternal(leftKeyType, rightKeyType, mergeStruct),
mergeInternal(leftValueType, rightValueType, mergeStruct),
@@ -650,17 +645,20 @@ object StructType extends AbstractDataType {
case (s1: StructType, s2: StructType) => mergeStruct(s1, s2)
- case (DecimalType.Fixed(leftPrecision, leftScale),
- DecimalType.Fixed(rightPrecision, rightScale)) =>
+ case (
+ DecimalType.Fixed(leftPrecision, leftScale),
+ DecimalType.Fixed(rightPrecision, rightScale)) =>
if (leftScale == rightScale) {
DecimalType(leftPrecision.max(rightPrecision), leftScale)
} else {
throw DataTypeErrors.cannotMergeDecimalTypesWithIncompatibleScaleError(
- leftScale, rightScale)
+ leftScale,
+ rightScale)
}
case (leftUdt: UserDefinedType[_], rightUdt: UserDefinedType[_])
- if leftUdt.userClass == rightUdt.userClass => leftUdt
+ if leftUdt.userClass == rightUdt.userClass =>
+ leftUdt
case (YearMonthIntervalType(lstart, lend), YearMonthIntervalType(rstart, rend)) =>
YearMonthIntervalType(Math.min(lstart, rstart).toByte, Math.max(lend, rend).toByte)
@@ -706,10 +704,12 @@ object StructType extends AbstractDataType {
// Found a missing field in `source`.
newFields += field
} else if (bothStructType(found.get.dataType, field.dataType) &&
- !found.get.dataType.sameType(field.dataType)) {
+ !found.get.dataType.sameType(field.dataType)) {
// Found a field with same name, but different data type.
- findMissingFields(found.get.dataType.asInstanceOf[StructType],
- field.dataType.asInstanceOf[StructType], resolver).map { missingType =>
+ findMissingFields(
+ found.get.dataType.asInstanceOf[StructType],
+ field.dataType.asInstanceOf[StructType],
+ resolver).map { missingType =>
newFields += found.get.copy(dataType = missingType)
}
}
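
To tie together the `add` overloads and lookup helpers whose docs are rewrapped above, a compact usage sketch (field names and comments are illustrative):

```scala
import org.apache.spark.sql.types._

val struct = new StructType()
  .add("a", IntegerType)                   // nullable, no metadata
  .add("b", LongType, false)               // explicit nullability
  .add("c", "string", true, "comment3")    // dataType given as a DDL string

struct("a")                 // extracts a single StructField
struct(Set("a", "c"))       // extracts a StructType, preserving field order
struct.fieldIndex("b")      // 1; unknown names raise FIELD_NOT_FOUND
println(struct.toDDL)       // DDL form of the schema, usable in a table creation
```
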
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/TimestampNTZType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/TimestampNTZType.scala
index 9968d75dd2577..b08d16f0e2c97 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/TimestampNTZType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/TimestampNTZType.scala
@@ -20,16 +20,17 @@ package org.apache.spark.sql.types
import org.apache.spark.annotation.Unstable
/**
- * The timestamp without time zone type represents a local time in microsecond precision,
- * which is independent of time zone.
- * Its valid range is [0001-01-01T00:00:00.000000, 9999-12-31T23:59:59.999999].
- * To represent an absolute point in time, use `TimestampType` instead.
+ * The timestamp without time zone type represents a local time in microsecond precision, which is
+ * independent of time zone. Its valid range is [0001-01-01T00:00:00.000000,
+ * 9999-12-31T23:59:59.999999]. To represent an absolute point in time, use `TimestampType`
+ * instead.
*
* Please use the singleton `DataTypes.TimestampNTZType` to refer the type.
* @since 3.4.0
*/
@Unstable
-class TimestampNTZType private() extends DatetimeType {
+class TimestampNTZType private () extends DatetimeType {
+
/**
* The default size of a value of the TimestampNTZType is 8 bytes.
*/
@@ -42,9 +43,9 @@ class TimestampNTZType private() extends DatetimeType {
/**
* The companion case object and its class is separated so the companion object also subclasses
- * the TimestampNTZType class. Otherwise, the companion object would be of type
- * "TimestampNTZType" in byte code. Defined with a private constructor so the companion
- * object is the only possible instantiation.
+ * the TimestampNTZType class. Otherwise, the companion object would be of type "TimestampNTZType"
+ * in byte code. Defined with a private constructor so the companion object is the only possible
+ * instantiation.
*
* @since 3.4.0
*/
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/TimestampType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/TimestampType.scala
index 1185e4a9e32ca..bf869d1f38c57 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/TimestampType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/TimestampType.scala
@@ -20,16 +20,16 @@ package org.apache.spark.sql.types
import org.apache.spark.annotation.Stable
/**
- * The timestamp type represents a time instant in microsecond precision.
- * Valid range is [0001-01-01T00:00:00.000000Z, 9999-12-31T23:59:59.999999Z] where
- * the left/right-bound is a date and time of the proleptic Gregorian
- * calendar in UTC+00:00.
+ * The timestamp type represents a time instant in microsecond precision. Valid range is
+ * [0001-01-01T00:00:00.000000Z, 9999-12-31T23:59:59.999999Z] where the left/right-bound is a date
+ * and time of the proleptic Gregorian calendar in UTC+00:00.
*
* Please use the singleton `DataTypes.TimestampType` to refer the type.
* @since 1.3.0
*/
@Stable
-class TimestampType private() extends DatetimeType {
+class TimestampType private () extends DatetimeType {
+
/**
* The default size of a value of the TimestampType is 8 bytes.
*/
@@ -40,8 +40,8 @@ class TimestampType private() extends DatetimeType {
/**
* The companion case object and its class is separated so the companion object also subclasses
- * the TimestampType class. Otherwise, the companion object would be of type "TimestampType$"
- * in byte code. Defined with a private constructor so the companion object is the only possible
+ * the TimestampType class. Otherwise, the companion object would be of type "TimestampType$" in
+ * byte code. Defined with a private constructor so the companion object is the only possible
* instantiation.
*
* @since 1.3.0
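
The two timestamp docs above describe the same microsecond precision but different semantics: `TimestampType` is an absolute instant, while `TimestampNTZType` is a local wall-clock time independent of time zone. A schema-level sketch:

```scala
import org.apache.spark.sql.types._

val events = new StructType()
  .add("received_at", TimestampType)     // absolute point in time
  .add("local_time", TimestampNTZType)   // wall-clock time, independent of time zone
```
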
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/UDTRegistration.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/UDTRegistration.scala
index 9219c1d139b99..85d421a07577b 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/UDTRegistration.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/UDTRegistration.scala
@@ -45,21 +45,26 @@ object UDTRegistration extends Serializable with Logging {
/**
* Queries if a given user class is already registered or not.
- * @param userClassName the name of user class
- * @return boolean value indicates if the given user class is registered or not
+ * @param userClassName
+ * the name of user class
+ * @return
+ * boolean value indicates if the given user class is registered or not
*/
def exists(userClassName: String): Boolean = udtMap.contains(userClassName)
/**
- * Registers an UserDefinedType to an user class. If the user class is already registered
- * with another UserDefinedType, warning log message will be shown.
- * @param userClass the name of user class
- * @param udtClass the name of UserDefinedType class for the given userClass
+ * Registers a UserDefinedType for a user class. If the user class is already registered with
+ * another UserDefinedType, a warning log message will be shown.
+ * @param userClass
+ * the name of user class
+ * @param udtClass
+ * the name of UserDefinedType class for the given userClass
*/
def register(userClass: String, udtClass: String): Unit = {
if (udtMap.contains(userClass)) {
- logWarning(log"Cannot register UDT for ${MDC(LogKeys.CLASS_NAME, userClass)}, " +
- log"which is already registered.")
+ logWarning(
+ log"Cannot register UDT for ${MDC(LogKeys.CLASS_NAME, userClass)}, " +
+ log"which is already registered.")
} else {
// When register UDT with class name, we can't check if the UDT class is an UserDefinedType,
// or not. The check is deferred.
@@ -69,8 +74,10 @@ object UDTRegistration extends Serializable with Logging {
/**
* Returns the Class of UserDefinedType for the name of a given user class.
- * @param userClass class name of user class
- * @return Option value of the Class object of UserDefinedType
+ * @param userClass
+ * class name of user class
+ * @return
+ * Option value of the Class object of UserDefinedType
*/
def getUDTFor(userClass: String): Option[Class[_]] = {
udtMap.get(userClass).map { udtClassName =>
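
The `UDTRegistration` API rewrapped above is deliberately string-based (the UDT class check is deferred); a hedged sketch of registering and looking up a UDT by class name, where the class names are placeholders:

```scala
import org.apache.spark.sql.types.UDTRegistration

// Both arguments are fully qualified class names; the UDT class is validated lazily.
UDTRegistration.register("com.example.Point", "com.example.PointUDT")

UDTRegistration.exists("com.example.Point")      // true
UDTRegistration.getUDTFor("com.example.Point")   // Some(udt class), if the class is loadable
```
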
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/UpCastRule.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/UpCastRule.scala
index 7ec00bde0b25f..4993e249b3059 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/UpCastRule.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/UpCastRule.scala
@@ -22,13 +22,8 @@ package org.apache.spark.sql.types
private[sql] object UpCastRule {
// See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
// The conversion for integral and floating point types have a linear widening hierarchy:
- val numericPrecedence: IndexedSeq[NumericType] = IndexedSeq(
- ByteType,
- ShortType,
- IntegerType,
- LongType,
- FloatType,
- DoubleType)
+ val numericPrecedence: IndexedSeq[NumericType] =
+ IndexedSeq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType)
/**
* Returns true iff we can safely up-cast the `from` type to `to` type without any truncating or
@@ -62,10 +57,9 @@ private[sql] object UpCastRule {
case (StructType(fromFields), StructType(toFields)) =>
fromFields.length == toFields.length &&
- fromFields.zip(toFields).forall {
- case (f1, f2) =>
- resolvableNullability(f1.nullable, f2.nullable) && canUpCast(f1.dataType, f2.dataType)
- }
+ fromFields.zip(toFields).forall { case (f1, f2) =>
+ resolvableNullability(f1.nullable, f2.nullable) && canUpCast(f1.dataType, f2.dataType)
+ }
case (_: DayTimeIntervalType, _: DayTimeIntervalType) => true
case (_: YearMonthIntervalType, _: YearMonthIntervalType) => true
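
The widening hierarchy above drives the numeric branch of `canUpCast`. Since `UpCastRule` is `private[sql]`, the snippet below is a hypothetical restatement of just the numeric-precedence check, not Spark's full rule (it ignores decimals, strings, structs, and intervals):

```scala
import org.apache.spark.sql.types._

val numericPrecedence: IndexedSeq[NumericType] =
  IndexedSeq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType)

// Safe widening: `from` may up-cast to `to` iff it sits no later in the hierarchy.
def canWidenNumeric(from: NumericType, to: NumericType): Boolean =
  numericPrecedence.indexOf(from) <= numericPrecedence.indexOf(to)

assert(canWidenNumeric(IntegerType, LongType))
assert(!canWidenNumeric(DoubleType, FloatType))
```
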
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
index 5cbd876b31e68..dd8ca26c52462 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
@@ -27,15 +27,14 @@ import org.apache.spark.annotation.{DeveloperApi, Since}
/**
* The data type for User Defined Types (UDTs).
*
- * This interface allows a user to make their own classes more interoperable with SparkSQL;
- * e.g., by creating a [[UserDefinedType]] for a class X, it becomes possible to create
- * a `DataFrame` which has class X in the schema.
+ * This interface allows a user to make their own classes more interoperable with SparkSQL; e.g.,
+ * by creating a [[UserDefinedType]] for a class X, it becomes possible to create a `DataFrame`
+ * which has class X in the schema.
*
- * For SparkSQL to recognize UDTs, the UDT must be annotated with
- * [[SQLUserDefinedType]].
+ * For SparkSQL to recognize UDTs, the UDT must be annotated with [[SQLUserDefinedType]].
*
- * The conversion via `serialize` occurs when instantiating a `DataFrame` from another RDD.
- * The conversion via `deserialize` occurs when reading from a `DataFrame`.
+ * The conversion via `serialize` occurs when instantiating a `DataFrame` from another RDD. The
+ * conversion via `deserialize` occurs when reading from a `DataFrame`.
*/
@DeveloperApi
@Since("3.2.0")
@@ -81,7 +80,7 @@ abstract class UserDefinedType[UserType >: Null] extends DataType with Serializa
override private[sql] def acceptsType(dataType: DataType): Boolean = dataType match {
case other: UserDefinedType[_] if this.userClass != null && other.userClass != null =>
this.getClass == other.getClass ||
- this.userClass.isAssignableFrom(other.userClass)
+ this.userClass.isAssignableFrom(other.userClass)
case _ => false
}
@@ -98,6 +97,7 @@ abstract class UserDefinedType[UserType >: Null] extends DataType with Serializa
}
private[spark] object UserDefinedType {
+
/**
* Get the sqlType of a (potential) [[UserDefinedType]].
*/
@@ -115,7 +115,8 @@ private[spark] object UserDefinedType {
private[sql] class PythonUserDefinedType(
val sqlType: DataType,
override val pyUDT: String,
- override val serializedPyClass: String) extends UserDefinedType[Any] {
+ override val serializedPyClass: String)
+ extends UserDefinedType[Any] {
/* The serialization is handled by UDT class in Python */
override def serialize(obj: Any): Any = obj
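
The UDT doc above says a class becomes SparkSQL-visible once paired with a `UserDefinedType` and annotated with `SQLUserDefinedType`. A hedged sketch under the usual contract (`sqlType`/`serialize`/`deserialize`/`userClass`); `Point` and `PointUDT` are made-up names, not part of this patch:

```scala
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.types._

@SQLUserDefinedType(udt = classOf[PointUDT])
case class Point(x: Double, y: Double)

class PointUDT extends UserDefinedType[Point] {
  // Underlying SQL representation used for storage and interchange.
  override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)

  // Called when a DataFrame is instantiated from Point objects.
  override def serialize(p: Point): Any = new GenericArrayData(Array(p.x, p.y))

  // Called when reading Point values back out of a DataFrame.
  override def deserialize(datum: Any): Point = datum match {
    case a: ArrayData => Point(a.getDouble(0), a.getDouble(1))
  }

  override def userClass: Class[Point] = classOf[Point]
}
```
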
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/VariantType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/VariantType.scala
index 103fe7a59fc83..4d775c3e1e390 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/VariantType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/VariantType.scala
@@ -20,9 +20,9 @@ package org.apache.spark.sql.types
import org.apache.spark.annotation.Unstable
/**
- * The data type representing semi-structured values with arbitrary hierarchical data structures. It
- * is intended to store parsed JSON values and most other data types in the system (e.g., it cannot
- * store a map with a non-string key type).
+ * The data type representing semi-structured values with arbitrary hierarchical data structures.
+ * It is intended to store parsed JSON values and most other data types in the system (e.g., it
+ * cannot store a map with a non-string key type).
*
* @since 4.0.0
*/
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/YearMonthIntervalType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/YearMonthIntervalType.scala
index 6532a3b220c5b..f69054f2c1fbc 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/YearMonthIntervalType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/YearMonthIntervalType.scala
@@ -29,18 +29,19 @@ import org.apache.spark.sql.types.YearMonthIntervalType.fieldToString
*
* `YearMonthIntervalType` represents positive as well as negative year-month intervals.
*
- * @param startField The leftmost field which the type comprises of. Valid values:
- * 0 (YEAR), 1 (MONTH).
- * @param endField The rightmost field which the type comprises of. Valid values:
- * 0 (YEAR), 1 (MONTH).
+ * @param startField
+ * The leftmost field which the type comprises of. Valid values: 0 (YEAR), 1 (MONTH).
+ * @param endField
+ * The rightmost field which the type comprises of. Valid values: 0 (YEAR), 1 (MONTH).
*
* @since 3.2.0
*/
@Unstable
case class YearMonthIntervalType(startField: Byte, endField: Byte) extends AnsiIntervalType {
+
/**
- * Year-month interval values always occupy 4 bytes.
- * The YEAR field is constrained by the upper bound 178956970 to fit to `Int`.
+ * Year-month interval values always occupy 4 bytes. The YEAR field is constrained by the upper
+ * bound 178956970 to fit to `Int`.
*/
override def defaultSize: Int = 4
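
Per the parameter docs above (`0` = YEAR, `1` = MONTH), a couple of illustrative instances; the no-arg factory seen in the ArrowUtils hunk below yields the full YEAR TO MONTH interval:

```scala
import org.apache.spark.sql.types.YearMonthIntervalType

val fullInterval = YearMonthIntervalType()      // YEAR TO MONTH
val yearOnly     = YearMonthIntervalType(0, 0)  // YEAR only
assert(fullInterval.defaultSize == 4)           // always 4 bytes, as documented above
```
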
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
index 6852fe09ef96b..1740cbe2957b8 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
@@ -38,33 +38,33 @@ private[sql] object ArrowUtils {
// todo: support more types.
/** Maps data type from Spark to Arrow. NOTE: timeZoneId required for TimestampTypes */
- def toArrowType(
- dt: DataType, timeZoneId: String, largeVarTypes: Boolean = false): ArrowType = dt match {
- case BooleanType => ArrowType.Bool.INSTANCE
- case ByteType => new ArrowType.Int(8, true)
- case ShortType => new ArrowType.Int(8 * 2, true)
- case IntegerType => new ArrowType.Int(8 * 4, true)
- case LongType => new ArrowType.Int(8 * 8, true)
- case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)
- case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
- case _: StringType if !largeVarTypes => ArrowType.Utf8.INSTANCE
- case BinaryType if !largeVarTypes => ArrowType.Binary.INSTANCE
- case _: StringType if largeVarTypes => ArrowType.LargeUtf8.INSTANCE
- case BinaryType if largeVarTypes => ArrowType.LargeBinary.INSTANCE
- case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale, 8 * 16)
- case DateType => new ArrowType.Date(DateUnit.DAY)
- case TimestampType if timeZoneId == null =>
- throw SparkException.internalError("Missing timezoneId where it is mandatory.")
- case TimestampType => new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId)
- case TimestampNTZType =>
- new ArrowType.Timestamp(TimeUnit.MICROSECOND, null)
- case NullType => ArrowType.Null.INSTANCE
- case _: YearMonthIntervalType => new ArrowType.Interval(IntervalUnit.YEAR_MONTH)
- case _: DayTimeIntervalType => new ArrowType.Duration(TimeUnit.MICROSECOND)
- case CalendarIntervalType => new ArrowType.Interval(IntervalUnit.MONTH_DAY_NANO)
- case _ =>
- throw ExecutionErrors.unsupportedDataTypeError(dt)
- }
+ def toArrowType(dt: DataType, timeZoneId: String, largeVarTypes: Boolean = false): ArrowType =
+ dt match {
+ case BooleanType => ArrowType.Bool.INSTANCE
+ case ByteType => new ArrowType.Int(8, true)
+ case ShortType => new ArrowType.Int(8 * 2, true)
+ case IntegerType => new ArrowType.Int(8 * 4, true)
+ case LongType => new ArrowType.Int(8 * 8, true)
+ case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)
+ case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
+ case _: StringType if !largeVarTypes => ArrowType.Utf8.INSTANCE
+ case BinaryType if !largeVarTypes => ArrowType.Binary.INSTANCE
+ case _: StringType if largeVarTypes => ArrowType.LargeUtf8.INSTANCE
+ case BinaryType if largeVarTypes => ArrowType.LargeBinary.INSTANCE
+ case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale, 8 * 16)
+ case DateType => new ArrowType.Date(DateUnit.DAY)
+ case TimestampType if timeZoneId == null =>
+ throw SparkException.internalError("Missing timezoneId where it is mandatory.")
+ case TimestampType => new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId)
+ case TimestampNTZType =>
+ new ArrowType.Timestamp(TimeUnit.MICROSECOND, null)
+ case NullType => ArrowType.Null.INSTANCE
+ case _: YearMonthIntervalType => new ArrowType.Interval(IntervalUnit.YEAR_MONTH)
+ case _: DayTimeIntervalType => new ArrowType.Duration(TimeUnit.MICROSECOND)
+ case CalendarIntervalType => new ArrowType.Interval(IntervalUnit.MONTH_DAY_NANO)
+ case _ =>
+ throw ExecutionErrors.unsupportedDataTypeError(dt)
+ }
def fromArrowType(dt: ArrowType): DataType = dt match {
case ArrowType.Bool.INSTANCE => BooleanType
@@ -73,9 +73,11 @@ private[sql] object ArrowUtils {
case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 4 => IntegerType
case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 8 => LongType
case float: ArrowType.FloatingPoint
- if float.getPrecision() == FloatingPointPrecision.SINGLE => FloatType
+ if float.getPrecision() == FloatingPointPrecision.SINGLE =>
+ FloatType
case float: ArrowType.FloatingPoint
- if float.getPrecision() == FloatingPointPrecision.DOUBLE => DoubleType
+ if float.getPrecision() == FloatingPointPrecision.DOUBLE =>
+ DoubleType
case ArrowType.Utf8.INSTANCE => StringType
case ArrowType.Binary.INSTANCE => BinaryType
case ArrowType.LargeUtf8.INSTANCE => StringType
@@ -83,13 +85,15 @@ private[sql] object ArrowUtils {
case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale)
case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType
case ts: ArrowType.Timestamp
- if ts.getUnit == TimeUnit.MICROSECOND && ts.getTimezone == null => TimestampNTZType
+ if ts.getUnit == TimeUnit.MICROSECOND && ts.getTimezone == null =>
+ TimestampNTZType
case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType
case ArrowType.Null.INSTANCE => NullType
- case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH => YearMonthIntervalType()
+ case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH =>
+ YearMonthIntervalType()
case di: ArrowType.Duration if di.getUnit == TimeUnit.MICROSECOND => DayTimeIntervalType()
- case ci: ArrowType.Interval
- if ci.getUnit == IntervalUnit.MONTH_DAY_NANO => CalendarIntervalType
+ case ci: ArrowType.Interval if ci.getUnit == IntervalUnit.MONTH_DAY_NANO =>
+ CalendarIntervalType
case _ => throw ExecutionErrors.unsupportedArrowTypeError(dt)
}
@@ -103,37 +107,54 @@ private[sql] object ArrowUtils {
dt match {
case ArrayType(elementType, containsNull) =>
val fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null)
- new Field(name, fieldType,
- Seq(toArrowField("element", elementType, containsNull, timeZoneId,
- largeVarTypes)).asJava)
+ new Field(
+ name,
+ fieldType,
+ Seq(
+ toArrowField("element", elementType, containsNull, timeZoneId, largeVarTypes)).asJava)
case StructType(fields) =>
val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null)
- new Field(name, fieldType,
- fields.map { field =>
- toArrowField(field.name, field.dataType, field.nullable, timeZoneId, largeVarTypes)
- }.toImmutableArraySeq.asJava)
+ new Field(
+ name,
+ fieldType,
+ fields
+ .map { field =>
+ toArrowField(field.name, field.dataType, field.nullable, timeZoneId, largeVarTypes)
+ }
+ .toImmutableArraySeq
+ .asJava)
case MapType(keyType, valueType, valueContainsNull) =>
val mapType = new FieldType(nullable, new ArrowType.Map(false), null)
// Note: Map Type struct can not be null, Struct Type key field can not be null
- new Field(name, mapType,
- Seq(toArrowField(MapVector.DATA_VECTOR_NAME,
- new StructType()
- .add(MapVector.KEY_NAME, keyType, nullable = false)
- .add(MapVector.VALUE_NAME, valueType, nullable = valueContainsNull),
- nullable = false,
- timeZoneId,
- largeVarTypes)).asJava)
+ new Field(
+ name,
+ mapType,
+ Seq(
+ toArrowField(
+ MapVector.DATA_VECTOR_NAME,
+ new StructType()
+ .add(MapVector.KEY_NAME, keyType, nullable = false)
+ .add(MapVector.VALUE_NAME, valueType, nullable = valueContainsNull),
+ nullable = false,
+ timeZoneId,
+ largeVarTypes)).asJava)
case udt: UserDefinedType[_] =>
toArrowField(name, udt.sqlType, nullable, timeZoneId, largeVarTypes)
case _: VariantType =>
- val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null,
+ val fieldType = new FieldType(
+ nullable,
+ ArrowType.Struct.INSTANCE,
+ null,
Map("variant" -> "true").asJava)
- new Field(name, fieldType,
- Seq(toArrowField("value", BinaryType, false, timeZoneId, largeVarTypes),
+ new Field(
+ name,
+ fieldType,
+ Seq(
+ toArrowField("value", BinaryType, false, timeZoneId, largeVarTypes),
toArrowField("metadata", BinaryType, false, timeZoneId, largeVarTypes)).asJava)
case dataType =>
- val fieldType = new FieldType(nullable, toArrowType(dataType, timeZoneId,
- largeVarTypes), null)
+ val fieldType =
+ new FieldType(nullable, toArrowType(dataType, timeZoneId, largeVarTypes), null)
new Field(name, fieldType, Seq.empty[Field].asJava)
}
}
@@ -149,9 +170,12 @@ private[sql] object ArrowUtils {
val elementField = field.getChildren().get(0)
val elementType = fromArrowField(elementField)
ArrayType(elementType, containsNull = elementField.isNullable)
- case ArrowType.Struct.INSTANCE if field.getMetadata.getOrDefault("variant", "") == "true"
- && field.getChildren.asScala.map(_.getName).asJava
- .containsAll(Seq("value", "metadata").asJava) =>
+ case ArrowType.Struct.INSTANCE
+ if field.getMetadata.getOrDefault("variant", "") == "true"
+ && field.getChildren.asScala
+ .map(_.getName)
+ .asJava
+ .containsAll(Seq("value", "metadata").asJava) =>
VariantType
case ArrowType.Struct.INSTANCE =>
val fields = field.getChildren().asScala.map { child =>
@@ -163,7 +187,9 @@ private[sql] object ArrowUtils {
}
}
- /** Maps schema from Spark to Arrow. NOTE: timeZoneId required for TimestampType in StructType */
+ /**
+ * Maps schema from Spark to Arrow. NOTE: timeZoneId required for TimestampType in StructType
+ */
def toArrowSchema(
schema: StructType,
timeZoneId: String,
@@ -187,14 +213,17 @@ private[sql] object ArrowUtils {
}
private def deduplicateFieldNames(
- dt: DataType, errorOnDuplicatedFieldNames: Boolean): DataType = dt match {
- case udt: UserDefinedType[_] => deduplicateFieldNames(udt.sqlType, errorOnDuplicatedFieldNames)
+ dt: DataType,
+ errorOnDuplicatedFieldNames: Boolean): DataType = dt match {
+ case udt: UserDefinedType[_] =>
+ deduplicateFieldNames(udt.sqlType, errorOnDuplicatedFieldNames)
case st @ StructType(fields) =>
val newNames = if (st.names.toSet.size == st.names.length) {
st.names
} else {
if (errorOnDuplicatedFieldNames) {
- throw ExecutionErrors.duplicatedFieldNameInArrowStructError(st.names.toImmutableArraySeq)
+ throw ExecutionErrors.duplicatedFieldNameInArrowStructError(
+ st.names.toImmutableArraySeq)
}
val genNawName = st.names.groupBy(identity).map {
case (name, names) if names.length > 1 =>
@@ -207,7 +236,10 @@ private[sql] object ArrowUtils {
val newFields =
fields.zip(newNames).map { case (StructField(_, dataType, nullable, metadata), name) =>
StructField(
- name, deduplicateFieldNames(dataType, errorOnDuplicatedFieldNames), nullable, metadata)
+ name,
+ deduplicateFieldNames(dataType, errorOnDuplicatedFieldNames),
+ nullable,
+ metadata)
}
StructType(newFields)
case ArrayType(elementType, containsNull) =>
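
`ArrowUtils` is `private[sql]`, so the mapping reformatted above is not called directly from user code; the sketch below simply restates a few of the documented Spark-to-Arrow pairs against the Arrow Java API (assumes `arrow-vector` on the classpath):

```scala
import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, TimeUnit}
import org.apache.arrow.vector.types.pojo.ArrowType

// IntegerType      -> 32-bit signed Arrow Int
val intArrow   = new ArrowType.Int(32, true)
// DoubleType       -> double-precision Arrow FloatingPoint
val dblArrow   = new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
// DateType         -> day-unit Arrow Date
val dateArrow  = new ArrowType.Date(DateUnit.DAY)
// TimestampNTZType -> microsecond Arrow Timestamp with no time zone
val tsNtzArrow = new ArrowType.Timestamp(TimeUnit.MICROSECOND, null)
```
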
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java
index 07a9409bc57a2..18646f67975c0 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionImplUtils.java
@@ -17,20 +17,25 @@
package org.apache.spark.sql.catalyst.expressions;
-import org.apache.spark.SparkBuildInfo;
-import org.apache.spark.sql.errors.QueryExecutionErrors;
-import org.apache.spark.unsafe.types.UTF8String;
-import org.apache.spark.util.VersionUtils;
-
-import javax.crypto.Cipher;
-import javax.crypto.spec.GCMParameterSpec;
-import javax.crypto.spec.IvParameterSpec;
-import javax.crypto.spec.SecretKeySpec;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.security.SecureRandom;
import java.security.spec.AlgorithmParameterSpec;
+import java.text.BreakIterator;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Locale;
+import javax.crypto.Cipher;
+import javax.crypto.spec.GCMParameterSpec;
+import javax.crypto.spec.IvParameterSpec;
+import javax.crypto.spec.SecretKeySpec;
+import org.apache.spark.SparkBuildInfo;
+import org.apache.spark.sql.catalyst.util.ArrayData;
+import org.apache.spark.sql.catalyst.util.GenericArrayData;
+import org.apache.spark.sql.errors.QueryExecutionErrors;
+import org.apache.spark.unsafe.types.UTF8String;
+import org.apache.spark.util.VersionUtils;
/**
* A utility class for constructing expressions.
@@ -272,4 +277,42 @@ private static byte[] aesInternal(
throw QueryExecutionErrors.aesCryptoError(e.getMessage());
}
}
+
+ public static ArrayData getSentences(
+ UTF8String str,
+ UTF8String language,
+ UTF8String country) {
+ if (str == null) return null;
+ Locale locale;
+ if (language != null && country != null) {
+ locale = new Locale(language.toString(), country.toString());
+ } else if (language != null) {
+ locale = new Locale(language.toString());
+ } else {
+ locale = Locale.US;
+ }
+ String sentences = str.toString();
+ BreakIterator sentenceInstance = BreakIterator.getSentenceInstance(locale);
+ sentenceInstance.setText(sentences);
+
+ int sentenceIndex = 0;
+ List<GenericArrayData> res = new ArrayList<>();
+ while (sentenceInstance.next() != BreakIterator.DONE) {
+ String sentence = sentences.substring(sentenceIndex, sentenceInstance.current());
+ sentenceIndex = sentenceInstance.current();
+ BreakIterator wordInstance = BreakIterator.getWordInstance(locale);
+ wordInstance.setText(sentence);
+ int wordIndex = 0;
+ List<UTF8String> words = new ArrayList<>();
+ while (wordInstance.next() != BreakIterator.DONE) {
+ String word = sentence.substring(wordIndex, wordInstance.current());
+ wordIndex = wordInstance.current();
+ if (Character.isLetterOrDigit(word.charAt(0))) {
+ words.add(UTF8String.fromString(word));
+ }
+ }
+ res.add(new GenericArrayData(words.toArray(new UTF8String[0])));
+ }
+ return new GenericArrayData(res.toArray(new GenericArrayData[0]));
+ }
}
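The new `getSentences` helper walks the input with `java.text.BreakIterator`: an outer sentence iterator slices the string into sentences, an inner word iterator splits each sentence, and only tokens that start with a letter or digit are kept. A minimal standalone sketch of the same boundary-walking pattern, in plain Scala and without Spark's `UTF8String`/`ArrayData` wrappers:

```scala
import java.text.BreakIterator
import java.util.Locale
import scala.collection.mutable.ArrayBuffer

// Splits `text` into sentences and each sentence into word tokens, mirroring
// the nested BreakIterator loops in getSentences above.
def sentences(text: String, locale: Locale = Locale.US): Seq[Seq[String]] = {
  val sentenceIt = BreakIterator.getSentenceInstance(locale)
  sentenceIt.setText(text)
  val result = ArrayBuffer.empty[Seq[String]]
  var sentenceStart = 0
  while (sentenceIt.next() != BreakIterator.DONE) {
    val sentence = text.substring(sentenceStart, sentenceIt.current())
    sentenceStart = sentenceIt.current()
    val wordIt = BreakIterator.getWordInstance(locale)
    wordIt.setText(sentence)
    val words = ArrayBuffer.empty[String]
    var wordStart = 0
    while (wordIt.next() != BreakIterator.DONE) {
      val word = sentence.substring(wordStart, wordIt.current())
      wordStart = wordIt.current()
      // Punctuation and whitespace tokens do not start with a letter or digit.
      if (word.nonEmpty && Character.isLetterOrDigit(word.charAt(0))) words += word
    }
    result += words.toSeq
  }
  result.toSeq
}

// sentences("Hi there. How are you?") == Seq(Seq("Hi", "there"), Seq("How", "are", "you"))
```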
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Column.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Column.java
index b191438dbc3ee..8b32940d7a657 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Column.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Column.java
@@ -53,7 +53,7 @@ static Column create(
boolean nullable,
String comment,
String metadataInJSON) {
- return new ColumnImpl(name, dataType, nullable, comment, null, null, metadataInJSON);
+ return new ColumnImpl(name, dataType, nullable, comment, null, null, null, metadataInJSON);
}
static Column create(
@@ -63,7 +63,8 @@ static Column create(
String comment,
ColumnDefaultValue defaultValue,
String metadataInJSON) {
- return new ColumnImpl(name, dataType, nullable, comment, defaultValue, null, metadataInJSON);
+ return new ColumnImpl(name, dataType, nullable, comment, defaultValue,
+ null, null, metadataInJSON);
}
static Column create(
@@ -74,7 +75,18 @@ static Column create(
String generationExpression,
String metadataInJSON) {
return new ColumnImpl(name, dataType, nullable, comment, null,
- generationExpression, metadataInJSON);
+ generationExpression, null, metadataInJSON);
+ }
+
+ static Column create(
+ String name,
+ DataType dataType,
+ boolean nullable,
+ String comment,
+ IdentityColumnSpec identityColumnSpec,
+ String metadataInJSON) {
+ return new ColumnImpl(name, dataType, nullable, comment, null,
+ null, identityColumnSpec, metadataInJSON);
}
/**
@@ -113,6 +125,12 @@ static Column create(
@Nullable
String generationExpression();
+ /**
+ * Returns the identity column specification of this table column. Null means no identity column.
+ */
+ @Nullable
+ IdentityColumnSpec identityColumnSpec();
+
/**
* Returns the column metadata in JSON format.
*/
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalogCapability.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalogCapability.java
index 5ccb15ff1f0a4..dceac1b484cf2 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalogCapability.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalogCapability.java
@@ -59,5 +59,23 @@ public enum TableCatalogCapability {
* {@link TableCatalog#createTable}.
* See {@link Column#defaultValue()}.
*/
- SUPPORT_COLUMN_DEFAULT_VALUE
+ SUPPORT_COLUMN_DEFAULT_VALUE,
+
+ /**
+ * Signals that the TableCatalog supports defining identity columns upon table creation in SQL.
+ *
+ * Without this capability, any create/replace table statements with an identity column defined
+ * in the table schema will throw an exception during analysis.
+ *
+ * An identity column is defined with syntax:
+ * {@code colName colType GENERATED ALWAYS AS IDENTITY(identityColumnSpec)}
+ * or
+ * {@code colName colType GENERATED BY DEFAULT AS IDENTITY(identityColumnSpec)}
+ * identityColumnSpec is defined with syntax: {@code [START WITH start | INCREMENT BY step]*}
+ *
+ * IdentitySpec is included in the column definition for APIs like
+ * {@link TableCatalog#createTable}.
+ * See {@link Column#identityColumnSpec()}.
+ */
+ SUPPORTS_CREATE_TABLE_WITH_IDENTITY_COLUMNS
}
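A hedged illustration of the documented syntax (the `testcat` catalog name and its identity-column support are assumptions, not part of this change):

```scala
// Only valid against a TableCatalog that reports
// SUPPORTS_CREATE_TABLE_WITH_IDENTITY_COLUMNS; otherwise the statement is
// rejected during analysis, as described in the javadoc above.
spark.sql("""
  CREATE TABLE testcat.ns.events (
    id BIGINT GENERATED ALWAYS AS IDENTITY (START WITH 1 INCREMENT BY 1),
    payload STRING
  )
""")
```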
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/ProcedureParameter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/ProcedureParameter.java
index 90d531ae21892..18c76833c5879 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/ProcedureParameter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/ProcedureParameter.java
@@ -32,6 +32,11 @@
*/
@Evolving
public interface ProcedureParameter {
+ /**
+ * A field metadata key that indicates whether an argument is passed by name.
+ */
+ String BY_NAME_METADATA_KEY = "BY_NAME";
+
/**
* Creates a builder for an IN procedure parameter.
*
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/UnboundProcedure.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/UnboundProcedure.java
index ee9a09055243b..1a91fd21bf07e 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/UnboundProcedure.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/UnboundProcedure.java
@@ -35,6 +35,12 @@ public interface UnboundProcedure extends Procedure {
* validate if the input types are compatible while binding or delegate that to Spark. Regardless,
* Spark will always perform the final validation of the arguments and rearrange them as needed
* based on {@link BoundProcedure#parameters() reported parameters}.
+ *
+ * The provided {@code inputType} is based on the procedure arguments. If an argument is passed
+ * by name, its metadata will indicate this with {@link ProcedureParameter#BY_NAME_METADATA_KEY}
+ * set to {@code true}. In such cases, the field name will match the name of the target procedure
+ * parameter. If the argument is not named, {@link ProcedureParameter#BY_NAME_METADATA_KEY} will
+ * not be set and the name will be assigned randomly.
*
* @param inputType the input types to bind to
* @return the bound procedure that is most suitable for the given input types
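Concretely, for a call mixing positional and named arguments, the `inputType` passed to `bind` looks roughly like the sketch below; the generated name for the positional field is illustrative, since the analyzer assigns it:

```scala
import org.apache.spark.sql.connector.catalog.procedures.ProcedureParameter
import org.apache.spark.sql.types.{IntegerType, MetadataBuilder, StringType, StructField, StructType}

// Metadata carried by fields that originate from named arguments.
val byName = new MetadataBuilder()
  .putBoolean(ProcedureParameter.BY_NAME_METADATA_KEY, value = true)
  .build()

// For CALL cat.proc(42, label => 'x'):
//  - the positional argument gets a generated field name and no BY_NAME metadata,
//  - the named argument keeps the parameter name and carries BY_NAME = true.
val inputType = StructType(Seq(
  StructField("param0", IntegerType, nullable = false),
  StructField("label", StringType, nullable = false, byName)))
```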
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
deleted file mode 100644
index 9b95f74db3a49..0000000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
+++ /dev/null
@@ -1,359 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql
-
-import java.lang.reflect.Modifier
-
-import scala.reflect.{classTag, ClassTag}
-import scala.reflect.runtime.universe.TypeTag
-
-import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
-import org.apache.spark.sql.catalyst.expressions.{BoundReference, Cast}
-import org.apache.spark.sql.catalyst.expressions.objects.{DecodeUsingSerializer, EncodeUsingSerializer}
-import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.types._
-
-/**
- * Methods for creating an [[Encoder]].
- *
- * @since 1.6.0
- */
-object Encoders {
-
- /**
- * An encoder for nullable boolean type.
- * The Scala primitive encoder is available as [[scalaBoolean]].
- * @since 1.6.0
- */
- def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder()
-
- /**
- * An encoder for nullable byte type.
- * The Scala primitive encoder is available as [[scalaByte]].
- * @since 1.6.0
- */
- def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder()
-
- /**
- * An encoder for nullable short type.
- * The Scala primitive encoder is available as [[scalaShort]].
- * @since 1.6.0
- */
- def SHORT: Encoder[java.lang.Short] = ExpressionEncoder()
-
- /**
- * An encoder for nullable int type.
- * The Scala primitive encoder is available as [[scalaInt]].
- * @since 1.6.0
- */
- def INT: Encoder[java.lang.Integer] = ExpressionEncoder()
-
- /**
- * An encoder for nullable long type.
- * The Scala primitive encoder is available as [[scalaLong]].
- * @since 1.6.0
- */
- def LONG: Encoder[java.lang.Long] = ExpressionEncoder()
-
- /**
- * An encoder for nullable float type.
- * The Scala primitive encoder is available as [[scalaFloat]].
- * @since 1.6.0
- */
- def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder()
-
- /**
- * An encoder for nullable double type.
- * The Scala primitive encoder is available as [[scalaDouble]].
- * @since 1.6.0
- */
- def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder()
-
- /**
- * An encoder for nullable string type.
- *
- * @since 1.6.0
- */
- def STRING: Encoder[java.lang.String] = ExpressionEncoder()
-
- /**
- * An encoder for nullable decimal type.
- *
- * @since 1.6.0
- */
- def DECIMAL: Encoder[java.math.BigDecimal] = ExpressionEncoder()
-
- /**
- * An encoder for nullable date type.
- *
- * @since 1.6.0
- */
- def DATE: Encoder[java.sql.Date] = ExpressionEncoder()
-
- /**
- * Creates an encoder that serializes instances of the `java.time.LocalDate` class
- * to the internal representation of nullable Catalyst's DateType.
- *
- * @since 3.0.0
- */
- def LOCALDATE: Encoder[java.time.LocalDate] = ExpressionEncoder()
-
- /**
- * Creates an encoder that serializes instances of the `java.time.LocalDateTime` class
- * to the internal representation of nullable Catalyst's TimestampNTZType.
- *
- * @since 3.4.0
- */
- def LOCALDATETIME: Encoder[java.time.LocalDateTime] = ExpressionEncoder()
-
- /**
- * An encoder for nullable timestamp type.
- *
- * @since 1.6.0
- */
- def TIMESTAMP: Encoder[java.sql.Timestamp] = ExpressionEncoder()
-
- /**
- * Creates an encoder that serializes instances of the `java.time.Instant` class
- * to the internal representation of nullable Catalyst's TimestampType.
- *
- * @since 3.0.0
- */
- def INSTANT: Encoder[java.time.Instant] = ExpressionEncoder()
-
- /**
- * An encoder for arrays of bytes.
- *
- * @since 1.6.1
- */
- def BINARY: Encoder[Array[Byte]] = ExpressionEncoder()
-
- /**
- * Creates an encoder that serializes instances of the `java.time.Duration` class
- * to the internal representation of nullable Catalyst's DayTimeIntervalType.
- *
- * @since 3.2.0
- */
- def DURATION: Encoder[java.time.Duration] = ExpressionEncoder()
-
- /**
- * Creates an encoder that serializes instances of the `java.time.Period` class
- * to the internal representation of nullable Catalyst's YearMonthIntervalType.
- *
- * @since 3.2.0
- */
- def PERIOD: Encoder[java.time.Period] = ExpressionEncoder()
-
- /**
- * Creates an encoder for Java Bean of type T.
- *
- * T must be publicly accessible.
- *
- * supported types for java bean field:
- * - primitive types: boolean, int, double, etc.
- * - boxed types: Boolean, Integer, Double, etc.
- * - String
- * - java.math.BigDecimal, java.math.BigInteger
- * - time related: java.sql.Date, java.sql.Timestamp, java.time.LocalDate, java.time.Instant
- * - collection types: array, java.util.List, and map
- * - nested java bean.
- *
- * @since 1.6.0
- */
- def bean[T](beanClass: Class[T]): Encoder[T] = ExpressionEncoder.javaBean(beanClass)
-
- /**
- * Creates a [[Row]] encoder for schema `schema`.
- *
- * @since 3.5.0
- */
- def row(schema: StructType): Encoder[Row] = ExpressionEncoder(schema)
-
- /**
- * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo.
- * This encoder maps T into a single byte array (binary) field.
- *
- * T must be publicly accessible.
- *
- * @since 1.6.0
- */
- def kryo[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = true)
-
- /**
- * Creates an encoder that serializes objects of type T using Kryo.
- * This encoder maps T into a single byte array (binary) field.
- *
- * T must be publicly accessible.
- *
- * @since 1.6.0
- */
- def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz))
-
- /**
- * (Scala-specific) Creates an encoder that serializes objects of type T using generic Java
- * serialization. This encoder maps T into a single byte array (binary) field.
- *
- * T must be publicly accessible.
- *
- * @note This is extremely inefficient and should only be used as the last resort.
- *
- * @since 1.6.0
- */
- def javaSerialization[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = false)
-
- /**
- * Creates an encoder that serializes objects of type T using generic Java serialization.
- * This encoder maps T into a single byte array (binary) field.
- *
- * T must be publicly accessible.
- *
- * @note This is extremely inefficient and should only be used as the last resort.
- *
- * @since 1.6.0
- */
- def javaSerialization[T](clazz: Class[T]): Encoder[T] = javaSerialization(ClassTag[T](clazz))
-
- /** Throws an exception if T is not a public class. */
- private def validatePublicClass[T: ClassTag](): Unit = {
- if (!Modifier.isPublic(classTag[T].runtimeClass.getModifiers)) {
- throw QueryExecutionErrors.notPublicClassError(classTag[T].runtimeClass.getName)
- }
- }
-
- /** A way to construct encoders using generic serializers. */
- private def genericSerializer[T: ClassTag](useKryo: Boolean): Encoder[T] = {
- if (classTag[T].runtimeClass.isPrimitive) {
- throw QueryExecutionErrors.primitiveTypesNotSupportedError()
- }
-
- validatePublicClass[T]()
-
- ExpressionEncoder[T](
- objSerializer =
- EncodeUsingSerializer(
- BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo),
- objDeserializer =
- DecodeUsingSerializer[T](
- Cast(GetColumnByOrdinal(0, BinaryType), BinaryType),
- classTag[T],
- kryo = useKryo),
- clsTag = classTag[T]
- )
- }
-
- /**
- * An encoder for 2-ary tuples.
- *
- * @since 1.6.0
- */
- def tuple[T1, T2](
- e1: Encoder[T1],
- e2: Encoder[T2]): Encoder[(T1, T2)] = {
- ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2))
- }
-
- /**
- * An encoder for 3-ary tuples.
- *
- * @since 1.6.0
- */
- def tuple[T1, T2, T3](
- e1: Encoder[T1],
- e2: Encoder[T2],
- e3: Encoder[T3]): Encoder[(T1, T2, T3)] = {
- ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3))
- }
-
- /**
- * An encoder for 4-ary tuples.
- *
- * @since 1.6.0
- */
- def tuple[T1, T2, T3, T4](
- e1: Encoder[T1],
- e2: Encoder[T2],
- e3: Encoder[T3],
- e4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = {
- ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4))
- }
-
- /**
- * An encoder for 5-ary tuples.
- *
- * @since 1.6.0
- */
- def tuple[T1, T2, T3, T4, T5](
- e1: Encoder[T1],
- e2: Encoder[T2],
- e3: Encoder[T3],
- e4: Encoder[T4],
- e5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = {
- ExpressionEncoder.tuple(
- encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4), encoderFor(e5))
- }
-
- /**
- * An encoder for Scala's product type (tuples, case classes, etc).
- * @since 2.0.0
- */
- def product[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder()
-
- /**
- * An encoder for Scala's primitive int type.
- * @since 2.0.0
- */
- def scalaInt: Encoder[Int] = ExpressionEncoder()
-
- /**
- * An encoder for Scala's primitive long type.
- * @since 2.0.0
- */
- def scalaLong: Encoder[Long] = ExpressionEncoder()
-
- /**
- * An encoder for Scala's primitive double type.
- * @since 2.0.0
- */
- def scalaDouble: Encoder[Double] = ExpressionEncoder()
-
- /**
- * An encoder for Scala's primitive float type.
- * @since 2.0.0
- */
- def scalaFloat: Encoder[Float] = ExpressionEncoder()
-
- /**
- * An encoder for Scala's primitive byte type.
- * @since 2.0.0
- */
- def scalaByte: Encoder[Byte] = ExpressionEncoder()
-
- /**
- * An encoder for Scala's primitive short type.
- * @since 2.0.0
- */
- def scalaShort: Encoder[Short] = ExpressionEncoder()
-
- /**
- * An encoder for Scala's primitive boolean type.
- * @since 2.0.0
- */
- def scalaBoolean: Encoder[Boolean] = ExpressionEncoder()
-
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
index 40b49506b58aa..4752434015375 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
@@ -19,11 +19,11 @@ package org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.{expressions => exprs}
import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue}
-import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, Codec}
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, Codec, JavaSerializationCodec, KryoSerializationCodec}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedLeafEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, YearMonthIntervalEncoder}
import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder}
import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, IsNull, Literal, MapKeys, MapValues, UpCast}
-import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, CreateExternalRow, InitializeJavaBean, Invoke, NewInstance, StaticInvoke, UnresolvedCatalystToExternalMap, UnresolvedMapObjects, WrapOption}
+import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, CreateExternalRow, DecodeUsingSerializer, InitializeJavaBean, Invoke, NewInstance, StaticInvoke, UnresolvedCatalystToExternalMap, UnresolvedMapObjects, WrapOption}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, IntervalUtils}
import org.apache.spark.sql.types._
@@ -411,6 +411,12 @@ object DeserializerBuildHelper {
val result = InitializeJavaBean(newInstance, setters.toMap)
exprs.If(IsNull(path), exprs.Literal.create(null, ObjectType(cls)), result)
+ case TransformingEncoder(tag, _, codec) if codec == JavaSerializationCodec =>
+ DecodeUsingSerializer(path, tag, kryo = false)
+
+ case TransformingEncoder(tag, _, codec) if codec == KryoSerializationCodec =>
+ DecodeUsingSerializer(path, tag, kryo = true)
+
case TransformingEncoder(tag, encoder, provider) =>
Invoke(
Literal.create(provider(), ObjectType(classOf[Codec[_, _]])),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ProjectingInternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ProjectingInternalRow.scala
index 408bd65333cac..20cf80e88e42a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ProjectingInternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ProjectingInternalRow.scala
@@ -26,7 +26,8 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal}
* An [[InternalRow]] that projects particular columns from another [[InternalRow]] without copying
* the underlying data.
*/
-case class ProjectingInternalRow(schema: StructType, colOrdinals: Seq[Int]) extends InternalRow {
+case class ProjectingInternalRow(schema: StructType,
+ colOrdinals: IndexedSeq[Int]) extends InternalRow {
assert(schema.size == colOrdinals.size)
private var row: InternalRow = _
@@ -116,3 +117,9 @@ case class ProjectingInternalRow(schema: StructType, colOrdinals: Seq[Int]) exte
row.get(colOrdinals(ordinal), dataType)
}
}
+
+object ProjectingInternalRow {
+ def apply(schema: StructType, colOrdinals: Seq[Int]): ProjectingInternalRow = {
+ new ProjectingInternalRow(schema, colOrdinals.toIndexedSeq)
+ }
+}
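Switching `colOrdinals` to `IndexedSeq` makes the per-field ordinal lookups constant time, while the new companion `apply` keeps existing call sites that pass a plain `Seq` source compatible. A small sketch of both construction paths, assuming the underlying row has at least three columns and we project the first and third:

```scala
import org.apache.spark.sql.catalyst.ProjectingInternalRow
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

val projectedSchema = StructType(Seq(
  StructField("a", IntegerType),
  StructField("c", StringType)))

// Existing call sites can keep passing a Seq; the companion converts it.
val viaSeq = ProjectingInternalRow(projectedSchema, Seq(0, 2))
// New code can construct the case class with an IndexedSeq directly.
val viaIndexedSeq = new ProjectingInternalRow(projectedSchema, IndexedSeq(0, 2))
```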
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
index 38bf0651d6f1c..daebe15c298f6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
@@ -21,7 +21,7 @@ import scala.language.existentials
import org.apache.spark.sql.catalyst.{expressions => exprs}
import org.apache.spark.sql.catalyst.DeserializerBuildHelper.expressionWithNullSafety
-import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, Codec}
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, Codec, JavaSerializationCodec, KryoSerializationCodec}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLeafEncoder, BoxedLongEncoder, BoxedShortEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveLeafEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, YearMonthIntervalEncoder}
import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder, lenientExternalDataTypeFor}
import org.apache.spark.sql.catalyst.expressions.{BoundReference, CheckOverflow, CreateNamedStruct, Expression, IsNull, KnownNotNull, Literal, UnsafeArrayData}
@@ -398,6 +398,12 @@ object SerializerBuildHelper {
}
createSerializerForObject(input, serializedFields)
+ case TransformingEncoder(_, _, codec) if codec == JavaSerializationCodec =>
+ EncodeUsingSerializer(input, kryo = false)
+
+ case TransformingEncoder(_, _, codec) if codec == KryoSerializationCodec =>
+ EncodeUsingSerializer(input, kryo = true)
+
case TransformingEncoder(_, encoder, codecProvider) =>
val encoded = Invoke(
Literal(codecProvider(), ObjectType(classOf[Codec[_, _]])),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 0164af945ca28..9e5b1d1254c87 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -26,7 +26,7 @@ import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters._
import scala.util.{Failure, Random, Success, Try}
-import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
+import org.apache.spark.{SparkException, SparkThrowable, SparkUnsupportedOperationException}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.catalog._
@@ -50,6 +50,7 @@ import org.apache.spark.sql.connector.catalog.{View => _, _}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.TableChange.{After, ColumnPosition}
import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction, ScalarFunction, UnboundFunction}
+import org.apache.spark.sql.connector.catalog.procedures.{BoundProcedure, ProcedureParameter, UnboundProcedure}
import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
@@ -310,6 +311,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
ExtractGenerator ::
ResolveGenerate ::
ResolveFunctions ::
+ ResolveProcedures ::
+ BindProcedures ::
ResolveTableSpec ::
ResolveAliases ::
ResolveSubquery ::
@@ -2611,6 +2614,66 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
}
}
+ /**
+ * A rule that resolves procedures.
+ */
+ object ResolveProcedures extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
+ _.containsPattern(UNRESOLVED_PROCEDURE), ruleId) {
+ case Call(UnresolvedProcedure(CatalogAndIdentifier(catalog, ident)), args, execute) =>
+ val procedureCatalog = catalog.asProcedureCatalog
+ val procedure = load(procedureCatalog, ident)
+ Call(ResolvedProcedure(procedureCatalog, ident, procedure), args, execute)
+ }
+
+ private def load(catalog: ProcedureCatalog, ident: Identifier): UnboundProcedure = {
+ try {
+ catalog.loadProcedure(ident)
+ } catch {
+ case e: Exception if !e.isInstanceOf[SparkThrowable] =>
+ val nameParts = catalog.name +: ident.asMultipartIdentifier
+ throw QueryCompilationErrors.failedToLoadRoutineError(nameParts, e)
+ }
+ }
+ }
+
+ /**
+ * A rule that binds procedures to the input types and rearranges arguments as needed.
+ */
+ object BindProcedures extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ case Call(ResolvedProcedure(catalog, ident, unbound: UnboundProcedure), args, execute)
+ if args.forall(_.resolved) =>
+ val inputType = extractInputType(args)
+ val bound = unbound.bind(inputType)
+ validateParameterModes(bound)
+ val rearrangedArgs = NamedParametersSupport.defaultRearrange(bound, args)
+ Call(ResolvedProcedure(catalog, ident, bound), rearrangedArgs, execute)
+ }
+
+ private def extractInputType(args: Seq[Expression]): StructType = {
+ val fields = args.zipWithIndex.map {
+ case (NamedArgumentExpression(name, value), _) =>
+ StructField(name, value.dataType, value.nullable, byNameMetadata)
+ case (arg, index) =>
+ StructField(s"param$index", arg.dataType, arg.nullable)
+ }
+ StructType(fields)
+ }
+
+ private def byNameMetadata: Metadata = {
+ new MetadataBuilder()
+ .putBoolean(ProcedureParameter.BY_NAME_METADATA_KEY, value = true)
+ .build()
+ }
+
+ private def validateParameterModes(procedure: BoundProcedure): Unit = {
+ procedure.parameters.find(_.mode != ProcedureParameter.Mode.IN).foreach { param =>
+ throw SparkException.internalError(s"Unsupported parameter mode: ${param.mode}")
+ }
+ }
+ }
+
/**
* This rule resolves and rewrites subqueries inside expressions.
*
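Taken together, `ResolveProcedures` and `BindProcedures` turn a `CALL` statement into a bound invocation: the procedure is loaded from the catalog, the argument-derived `StructType` is bound, parameter modes are validated to be IN, and named arguments are rearranged into parameter order. A hedged usage sketch (the catalog and procedure names are assumptions):

```scala
// Assumes `cat` is registered as a ProcedureCatalog exposing a `sys.refresh`
// procedure with IN parameters (table STRING, retries INT).
spark.sql("CALL cat.sys.refresh('db.events', retries => 3)")
```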
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
index 17b1c4e249f57..3afe0ec8e9a7a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
@@ -77,6 +77,7 @@ object AnsiTypeCoercion extends TypeCoercionBase {
override def typeCoercionRules: List[Rule[LogicalPlan]] =
UnpivotCoercion ::
WidenSetOperationTypes ::
+ ProcedureArgumentCoercion ::
new AnsiCombinedTypeCoercionRule(
CollationTypeCasts ::
InConversion ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index a9fbe548ba39e..5a9d5cd87ecc7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -250,6 +250,9 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
context = u.origin.getQueryContext,
summary = u.origin.context.summary)
+ case u: UnresolvedInlineTable if unresolvedInlineTableContainsScalarSubquery(u) =>
+ throw QueryCompilationErrors.inlineTableContainsScalarSubquery(u)
+
case command: V2PartitionCommand =>
command.table match {
case r @ ResolvedTable(_, _, table, _) => table match {
@@ -673,6 +676,14 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
varName,
c.defaultExpr.originalSQL)
+ case c: Call if c.resolved && c.bound && c.checkArgTypes().isFailure =>
+ c.checkArgTypes() match {
+ case mismatch: TypeCheckResult.DataTypeMismatch =>
+ c.dataTypeMismatch("CALL", mismatch)
+ case _ =>
+ throw SparkException.internalError("Invalid input for procedure")
+ }
+
case _ => // Falls back to the following checks
}
@@ -1559,6 +1570,15 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
case _ =>
}
}
+
+ private def unresolvedInlineTableContainsScalarSubquery(
+ unresolvedInlineTable: UnresolvedInlineTable) = {
+ unresolvedInlineTable.rows.exists { row =>
+ row.exists { expression =>
+ expression.exists(_.isInstanceOf[ScalarSubquery])
+ }
+ }
+ }
}
// a heap of the preempted error that only keeps the top priority element, representing the sole
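The new check surfaces a dedicated error when a scalar subquery appears inside an inline table, instead of letting the query fail later with a less specific message. A hedged sketch of a query that now trips it:

```scala
// Each VALUES row is scanned for a ScalarSubquery; the subquery in the second
// column triggers inlineTableContainsScalarSubquery during analysis.
spark.sql("SELECT * FROM VALUES (1, (SELECT max(id) FROM range(10))) AS t(a, b)")
```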
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
index 0fa11b9c45038..e22a4b941b30c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
@@ -105,11 +105,15 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
case p: LogicalPlan if p.isStreaming => (plan, false)
case m: MultiInstanceRelation =>
- deduplicateAndRenew[LogicalPlan with MultiInstanceRelation](
- existingRelations,
- m,
- _.output.map(_.exprId.id),
- node => node.newInstance().asInstanceOf[LogicalPlan with MultiInstanceRelation])
+ val planWrapper = RelationWrapper(m.getClass, m.output.map(_.exprId.id))
+ if (existingRelations.contains(planWrapper)) {
+ val newNode = m.newInstance()
+ newNode.copyTagsFrom(m)
+ (newNode, true)
+ } else {
+ existingRelations.add(planWrapper)
+ (m, false)
+ }
case p: Project =>
deduplicateAndRenew[Project](
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index dfe1bd12bb7ff..d03d8114e9976 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -384,7 +384,9 @@ object FunctionRegistry {
expression[Rand]("rand"),
expression[Rand]("random", true, Some("3.0.0")),
expression[Randn]("randn"),
+ expression[RandStr]("randstr"),
expression[Stack]("stack"),
+ expression[Uniform]("uniform"),
expression[ZeroIfNull]("zeroifnull"),
CaseWhen.registryEntry,
@@ -839,6 +841,7 @@ object FunctionRegistry {
expressionBuilder("try_variant_get", TryVariantGetExpressionBuilder),
expression[SchemaOfVariant]("schema_of_variant"),
expression[SchemaOfVariantAgg]("schema_of_variant_agg"),
+ expression[ToVariantObject]("to_variant_object"),
// cast
expression[Cast]("cast"),
@@ -1157,6 +1160,7 @@ object TableFunctionRegistry {
generator[PosExplode]("posexplode"),
generator[PosExplode]("posexplode_outer", outer = true),
generator[Stack]("stack"),
+ generator[Collations]("collations"),
generator[SQLKeywords]("sql_keywords"),
generator[VariantExplode]("variant_explode"),
generator[VariantExplode]("variant_explode_outer", outer = true)
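The registry additions expose `randstr` and `uniform` as SQL scalar functions and `collations` as a table function. A hedged usage sketch (result values are random; the argument shapes follow the registered expressions):

```scala
// uniform(min, max) draws a random number between min and max, and
// randstr(n) produces a random string of length n; both accept an optional seed.
spark.sql("SELECT uniform(0, 10) AS u, randstr(8) AS s FROM range(3)").show()
// The new table function lists the available collations.
spark.sql("SELECT * FROM collations()").show()
```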
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 08c5b3531b4c8..5983346ff1e27 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.AlwaysProcess
import org.apache.spark.sql.catalyst.types.DataTypeUtils
+import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.{AbstractArrayType, AbstractMapType, AbstractStringType, StringTypeAnyCollation}
@@ -202,6 +203,20 @@ abstract class TypeCoercionBase {
}
}
+ /**
+ * A type coercion rule that implicitly casts procedure arguments to expected types.
+ */
+ object ProcedureArgumentCoercion extends Rule[LogicalPlan] {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ case c @ Call(ResolvedProcedure(_, _, procedure: BoundProcedure), args, _) if c.resolved =>
+ val expectedDataTypes = procedure.parameters.map(_.dataType)
+ val coercedArgs = args.zip(expectedDataTypes).map {
+ case (arg, expectedType) => implicitCast(arg, expectedType).getOrElse(arg)
+ }
+ c.copy(args = coercedArgs)
+ }
+ }
+
/**
* Widens the data types of the [[Unpivot]] values.
*/
@@ -838,6 +853,7 @@ object TypeCoercion extends TypeCoercionBase {
override def typeCoercionRules: List[Rule[LogicalPlan]] =
UnpivotCoercion ::
WidenSetOperationTypes ::
+ ProcedureArgumentCoercion ::
new CombinedTypeCoercionRule(
CollationTypeCasts ::
InConversion ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
index c0689eb121679..daab9e4d78bf5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
@@ -67,9 +67,13 @@ package object analysis {
}
def dataTypeMismatch(expr: Expression, mismatch: DataTypeMismatch): Nothing = {
+ dataTypeMismatch(toSQLExpr(expr), mismatch)
+ }
+
+ def dataTypeMismatch(sqlExpr: String, mismatch: DataTypeMismatch): Nothing = {
throw new AnalysisException(
errorClass = s"DATATYPE_MISMATCH.${mismatch.errorSubClass}",
- messageParameters = mismatch.messageParameters + ("sqlExpr" -> toSQLExpr(expr)),
+ messageParameters = mismatch.messageParameters + ("sqlExpr" -> sqlExpr),
origin = t.origin)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 60d979e9c7afb..6f445b1e88d70 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -974,3 +974,13 @@ case object UnresolvedWithinGroup extends LeafExpression with Unevaluable {
override def dataType: DataType = throw new UnresolvedException("dataType")
override lazy val resolved = false
}
+
+case class UnresolvedTranspose(
+ indices: Seq[Expression],
+ child: LogicalPlan
+) extends UnresolvedUnaryNode {
+ final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_TRANSPOSE)
+
+ override protected def withNewChildInternal(newChild: LogicalPlan): UnresolvedTranspose =
+ copy(child = newChild)
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala
index ecdf40e87a894..dee78b8f03af4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala
@@ -23,13 +23,14 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Attribute, LeafExpression, Unevaluable}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
-import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UNRESOLVED_FUNC}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UNRESOLVED_FUNC, UNRESOLVED_PROCEDURE}
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
-import org.apache.spark.sql.connector.catalog.{CatalogPlugin, FunctionCatalog, Identifier, Table, TableCatalog}
+import org.apache.spark.sql.connector.catalog.{CatalogPlugin, FunctionCatalog, Identifier, ProcedureCatalog, Table, TableCatalog}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
+import org.apache.spark.sql.connector.catalog.procedures.Procedure
import org.apache.spark.sql.types.{DataType, StructField}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.ArrayImplicits._
@@ -135,6 +136,12 @@ case class UnresolvedFunctionName(
case class UnresolvedIdentifier(nameParts: Seq[String], allowTemp: Boolean = false)
extends UnresolvedLeafNode
+/**
+ * A procedure identifier that should be resolved into [[ResolvedProcedure]].
+ */
+case class UnresolvedProcedure(nameParts: Seq[String]) extends UnresolvedLeafNode {
+ final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_PROCEDURE)
+}
/**
* A resolved leaf node whose statistics has no meaning.
@@ -192,6 +199,12 @@ case class ResolvedFieldName(path: Seq[String], field: StructField) extends Fiel
case class ResolvedFieldPosition(position: ColumnPosition) extends FieldPosition
+case class ResolvedProcedure(
+ catalog: ProcedureCatalog,
+ ident: Identifier,
+ procedure: Procedure) extends LeafNodeWithoutStats {
+ override def output: Seq[Attribute] = Nil
+}
/**
* A plan containing resolved persistent views.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala
index 2c27da3cf6e15..5444ab6845867 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala
@@ -21,12 +21,12 @@ import java.util.Locale
import scala.util.control.Exception.allCatch
+import org.apache.spark.SparkException
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.TypeCoercion
import org.apache.spark.sql.catalyst.expressions.ExprUtils
import org.apache.spark.sql.catalyst.util.{DateFormatter, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT
-import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -138,7 +138,7 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable {
case BooleanType => tryParseBoolean(field)
case StringType => StringType
case other: DataType =>
- throw QueryExecutionErrors.dataTypeUnexpectedError(other)
+ throw SparkException.internalError(s"Unexpected data type $other")
}
compatibleType(typeSoFar, typeElemInfer).getOrElse(StringType)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 0b5ce65fed6df..d7d53230470d9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.{DeserializerBuildHelper, InternalRow, Java
import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.{Deserializer, Serializer}
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, InitializeJavaBean, Invoke, NewInstance}
+import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, InitializeJavaBean, NewInstance}
import org.apache.spark.sql.catalyst.optimizer.{ReassignLambdaVariableID, SimplifyCasts}
import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LeafNode, LocalRelation}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
@@ -54,9 +54,9 @@ object ExpressionEncoder {
def apply[T](enc: AgnosticEncoder[T]): ExpressionEncoder[T] = {
new ExpressionEncoder[T](
+ enc,
SerializerBuildHelper.createSerializer(enc),
- DeserializerBuildHelper.createDeserializer(enc),
- enc.clsTag)
+ DeserializerBuildHelper.createDeserializer(enc))
}
def apply(schema: StructType): ExpressionEncoder[Row] = apply(schema, lenient = false)
@@ -70,107 +70,6 @@ object ExpressionEncoder {
apply(JavaTypeInference.encoderFor(beanClass))
}
- /**
- * Given a set of N encoders, constructs a new encoder that produce objects as items in an
- * N-tuple. Note that these encoders should be unresolved so that information about
- * name/positional binding is preserved.
- * When `useNullSafeDeserializer` is true, the deserialization result for a child will be null if
- * the input is null. It is false by default as most deserializers handle null input properly and
- * don't require an extra null check. Some of them are null-tolerant, such as the deserializer for
- * `Option[T]`, and we must not set it to true in this case.
- */
- def tuple(
- encoders: Seq[ExpressionEncoder[_]],
- useNullSafeDeserializer: Boolean = false): ExpressionEncoder[_] = {
- if (encoders.length > 22) {
- throw QueryExecutionErrors.elementsOfTupleExceedLimitError()
- }
-
- encoders.foreach(_.assertUnresolved())
-
- val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
-
- val newSerializerInput = BoundReference(0, ObjectType(cls), nullable = true)
- val serializers = encoders.zipWithIndex.map { case (enc, index) =>
- val boundRefs = enc.objSerializer.collect { case b: BoundReference => b }.distinct
- assert(boundRefs.size == 1, "object serializer should have only one bound reference but " +
- s"there are ${boundRefs.size}")
-
- val originalInputObject = boundRefs.head
- val newInputObject = Invoke(
- newSerializerInput,
- s"_${index + 1}",
- originalInputObject.dataType,
- returnNullable = originalInputObject.nullable)
-
- val newSerializer = enc.objSerializer.transformUp {
- case BoundReference(0, _, _) => newInputObject
- }
-
- Alias(newSerializer, s"_${index + 1}")()
- }
- val newSerializer = CreateStruct(serializers)
-
- def nullSafe(input: Expression, result: Expression): Expression = {
- If(IsNull(input), Literal.create(null, result.dataType), result)
- }
-
- val newDeserializerInput = GetColumnByOrdinal(0, newSerializer.dataType)
- val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) =>
- val getColExprs = enc.objDeserializer.collect { case c: GetColumnByOrdinal => c }.distinct
- assert(getColExprs.size == 1, "object deserializer should have only one " +
- s"`GetColumnByOrdinal`, but there are ${getColExprs.size}")
-
- val input = GetStructField(newDeserializerInput, index)
- val childDeserializer = enc.objDeserializer.transformUp {
- case GetColumnByOrdinal(0, _) => input
- }
-
- if (useNullSafeDeserializer && enc.objSerializer.nullable) {
- nullSafe(input, childDeserializer)
- } else {
- childDeserializer
- }
- }
- val newDeserializer =
- NewInstance(cls, childrenDeserializers, ObjectType(cls), propagateNull = false)
-
- new ExpressionEncoder[Any](
- nullSafe(newSerializerInput, newSerializer),
- nullSafe(newDeserializerInput, newDeserializer),
- ClassTag(cls))
- }
-
- // Tuple1
- def tuple[T](e: ExpressionEncoder[T]): ExpressionEncoder[Tuple1[T]] =
- tuple(Seq(e)).asInstanceOf[ExpressionEncoder[Tuple1[T]]]
-
- def tuple[T1, T2](
- e1: ExpressionEncoder[T1],
- e2: ExpressionEncoder[T2]): ExpressionEncoder[(T1, T2)] =
- tuple(Seq(e1, e2)).asInstanceOf[ExpressionEncoder[(T1, T2)]]
-
- def tuple[T1, T2, T3](
- e1: ExpressionEncoder[T1],
- e2: ExpressionEncoder[T2],
- e3: ExpressionEncoder[T3]): ExpressionEncoder[(T1, T2, T3)] =
- tuple(Seq(e1, e2, e3)).asInstanceOf[ExpressionEncoder[(T1, T2, T3)]]
-
- def tuple[T1, T2, T3, T4](
- e1: ExpressionEncoder[T1],
- e2: ExpressionEncoder[T2],
- e3: ExpressionEncoder[T3],
- e4: ExpressionEncoder[T4]): ExpressionEncoder[(T1, T2, T3, T4)] =
- tuple(Seq(e1, e2, e3, e4)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4)]]
-
- def tuple[T1, T2, T3, T4, T5](
- e1: ExpressionEncoder[T1],
- e2: ExpressionEncoder[T2],
- e3: ExpressionEncoder[T3],
- e4: ExpressionEncoder[T4],
- e5: ExpressionEncoder[T5]): ExpressionEncoder[(T1, T2, T3, T4, T5)] =
- tuple(Seq(e1, e2, e3, e4, e5)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]]
-
private val anyObjectType = ObjectType(classOf[Any])
/**
@@ -228,6 +127,7 @@ object ExpressionEncoder {
* A generic encoder for JVM objects that uses Catalyst Expressions for a `serializer`
* and a `deserializer`.
*
+ * @param encoder the `AgnosticEncoder` for type `T`.
* @param objSerializer An expression that can be used to encode a raw object to corresponding
* Spark SQL representation that can be a primitive column, array, map or a
* struct. This represents how Spark SQL generally serializes an object of
@@ -236,13 +136,15 @@ object ExpressionEncoder {
* representation. This represents how Spark SQL generally deserializes
* a serialized value in Spark SQL representation back to an object of
* type `T`.
- * @param clsTag A classtag for `T`.
*/
case class ExpressionEncoder[T](
+ encoder: AgnosticEncoder[T],
objSerializer: Expression,
- objDeserializer: Expression,
- clsTag: ClassTag[T])
- extends Encoder[T] {
+ objDeserializer: Expression)
+ extends Encoder[T]
+ with ToAgnosticEncoder[T] {
+
+ override def clsTag: ClassTag[T] = encoder.clsTag
/**
* A sequence of expressions, one for each top-level field that can be used to
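After the refactor an `ExpressionEncoder` carries the `AgnosticEncoder` it was built from and derives `clsTag` from it, and tuple composition is no longer done at this level. A minimal construction sketch using one of the primitive agnostic encoders:

```scala
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

// The single-argument apply derives the serializer, deserializer and ClassTag
// from the agnostic encoder; clsTag is no longer passed explicitly.
val enc: ExpressionEncoder[String] = ExpressionEncoder(StringEncoder)
assert(enc.clsTag.runtimeClass == classOf[String])
```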
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CannotReplaceMissingTableException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/KryoSerializationCodecImpl.scala
similarity index 57%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CannotReplaceMissingTableException.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/KryoSerializationCodecImpl.scala
index f3e0c0aca29ca..49c7b41f77472 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CannotReplaceMissingTableException.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/KryoSerializationCodecImpl.scala
@@ -14,20 +14,20 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+package org.apache.spark.sql.catalyst.encoders
+import java.nio.ByteBuffer
-package org.apache.spark.sql.catalyst.analysis
+import org.apache.spark.sql.catalyst.expressions.objects.SerializerSupport
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.util.quoteNameParts
-import org.apache.spark.sql.connector.catalog.Identifier
-import org.apache.spark.util.ArrayImplicits._
+/**
+ * A codec that uses Kryo to (de)serialize arbitrary objects to and from a byte array.
+ */
+class KryoSerializationCodecImpl extends Codec[Any, Array[Byte]] {
+ private val serializer = SerializerSupport.newSerializer(useKryo = true)
+ override def encode(in: Any): Array[Byte] =
+ serializer.serialize(in).array()
-class CannotReplaceMissingTableException(
- tableIdentifier: Identifier,
- cause: Option[Throwable] = None)
- extends AnalysisException(
- errorClass = "TABLE_OR_VIEW_NOT_FOUND",
- messageParameters = Map("relationName"
- -> quoteNameParts((tableIdentifier.namespace :+ tableIdentifier.name).toImmutableArraySeq)),
- cause = cause)
+ override def decode(out: Array[Byte]): Any =
+ serializer.deserialize(ByteBuffer.wrap(out))
+}
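The codec is symmetric: whatever `encode` writes with the Kryo-backed serializer, `decode` reads back with the same settings. A small round-trip sketch (assumes Kryo can serialize the payload, which holds for common Scala collections):

```scala
import org.apache.spark.sql.catalyst.encoders.KryoSerializationCodecImpl

// Round-trips an arbitrary object through the Kryo-backed codec.
val codec = new KryoSerializationCodecImpl
val bytes: Array[Byte] = codec.encode(Seq(1, 2, 3))
val restored = codec.decode(bytes).asInstanceOf[Seq[Int]]
assert(restored == Seq(1, 2, 3))
```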
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Between.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Between.scala
index de1122da646b7..c226e48c6be5e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Between.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Between.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.internal.SQLConf
// scalastyle:off line.size.limit
@ExpressionDescription(
- usage = "Usage: input [NOT] BETWEEN lower AND upper - evaluate if `input` is [not] in between `lower` and `upper`",
+ usage = "input [NOT] _FUNC_ lower AND upper - evaluate if `input` is [not] in between `lower` and `upper`",
examples = """
Examples:
> SELECT 0.5 _FUNC_ 0.1 AND 1.0;
@@ -33,7 +33,7 @@ import org.apache.spark.sql.internal.SQLConf
* lower - Lower bound of the between check.
* upper - Upper bound of the between check.
""",
- since = "4.0.0",
+ since = "1.0.0",
group = "conditional_funcs")
case class Between private(input: Expression, lower: Expression, upper: Expression, replacement: Expression)
extends RuntimeReplaceable with InheritAnalysisRules {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 4a2b4b28e690e..7a2799e99fe2d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -128,7 +128,10 @@ object Cast extends QueryErrorsBase {
case (TimestampType, _: NumericType) => true
case (VariantType, _) => variant.VariantGet.checkDataType(to)
- case (_, VariantType) => variant.VariantGet.checkDataType(from)
+ // Structs and Maps can't be cast to Variants since the Variant spec does not yet contain
+ // lossless equivalents for these types. The `to_variant_object` expression can be used instead
+ // to convert data of these types to Variant Objects.
+ case (_, VariantType) => variant.VariantGet.checkDataType(from, allowStructsAndMaps = false)
case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
canAnsiCast(fromType, toType) && resolvableNullability(fn, tn)
@@ -237,7 +240,10 @@ object Cast extends QueryErrorsBase {
case (_: NumericType, _: NumericType) => true
case (VariantType, _) => variant.VariantGet.checkDataType(to)
- case (_, VariantType) => variant.VariantGet.checkDataType(from)
+ // Structs and Maps can't be cast to Variants since the Variant spec does not yet contain
+ // lossless equivalents for these types. The `to_variant_object` expression can be used instead
+ // to convert data of these types to Variant Objects.
+ case (_, VariantType) => variant.VariantGet.checkDataType(from, allowStructsAndMaps = false)
case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
canCast(fromType, toType) &&
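Since struct and map values are no longer implicitly castable to VARIANT, callers perform the conversion explicitly with the new expression. A hedged sketch:

```scala
// CAST(struct AS VARIANT) is now rejected; to_variant_object performs the
// explicit conversion to a Variant Object instead.
spark.sql("SELECT to_variant_object(named_struct('a', 1, 'b', 'x')) AS v")
```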
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PipeSelect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PipeSelect.scala
new file mode 100644
index 0000000000000..0b5479cc8f0ee
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PipeSelect.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
+import org.apache.spark.sql.catalyst.trees.TreePattern.{PIPE_OPERATOR_SELECT, RUNTIME_REPLACEABLE, TreePattern}
+import org.apache.spark.sql.errors.QueryCompilationErrors
+
+/**
+ * Represents a SELECT clause when used with the |> SQL pipe operator.
+ * We use this to make sure that no aggregate functions exist in the SELECT expressions.
+ */
+case class PipeSelect(child: Expression)
+ extends UnaryExpression with RuntimeReplaceable {
+ final override val nodePatterns: Seq[TreePattern] = Seq(PIPE_OPERATOR_SELECT, RUNTIME_REPLACEABLE)
+ override def withNewChildInternal(newChild: Expression): Expression = PipeSelect(newChild)
+ override lazy val replacement: Expression = {
+ def visit(e: Expression): Unit = e match {
+ case a: AggregateFunction =>
+ // If we used the pipe operator |> SELECT clause to specify an aggregate function, this is
+ // invalid; return an error message instructing the user to use the pipe operator
+ // |> AGGREGATE clause for this purpose instead.
+ throw QueryCompilationErrors.pipeOperatorSelectContainsAggregateFunction(a)
+ case _: WindowExpression =>
+ // Window functions are allowed in pipe SELECT operators, so do not traverse into children.
+ case _ =>
+ e.children.foreach(visit)
+ }
+ visit(child)
+ child
+ }
+}
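In practice an aggregate that slips into a piped SELECT is rejected at analysis time, steering users toward the dedicated aggregate clause. A hedged SQL-level sketch, assuming pipe operator syntax is enabled in the session:

```scala
// Plain computed columns are fine in a piped SELECT ...
spark.sql("SELECT * FROM VALUES (1), (2) AS t(x) |> SELECT x + 1 AS y")
// ... but aggregate functions are not; PipeSelect raises
// pipeOperatorSelectContainsAggregateFunction for a query like:
//   SELECT * FROM VALUES (1), (2) AS t(x) |> SELECT sum(x)
```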
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
index 4987e31b49911..8ad062ab0e2f9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
@@ -21,6 +21,7 @@ import java.nio.ByteBuffer
import com.google.common.primitives.{Doubles, Ints, Longs}
+import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
@@ -32,7 +33,6 @@ import org.apache.spark.sql.catalyst.types.PhysicalNumericType
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.catalyst.util.QuantileSummaries
import org.apache.spark.sql.catalyst.util.QuantileSummaries.{defaultCompressThreshold, Stats}
-import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types._
import org.apache.spark.util.ArrayImplicits._
@@ -189,7 +189,7 @@ case class ApproximatePercentile(
PhysicalNumericType.numeric(n)
.toDouble(value.asInstanceOf[PhysicalNumericType#InternalType])
case other: DataType =>
- throw QueryExecutionErrors.dataTypeUnexpectedError(other)
+ throw SparkException.internalError(s"Unexpected data type $other")
}
buffer.add(doubleValue)
}
@@ -214,7 +214,7 @@ case class ApproximatePercentile(
case DoubleType => doubleResult
case _: DecimalType => doubleResult.map(Decimal(_))
case other: DataType =>
- throw QueryExecutionErrors.dataTypeUnexpectedError(other)
+ throw SparkException.internalError(s"Unexpected data type $other")
}
if (result.length == 0) {
null
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumeric.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumeric.scala
index ba26c5a1022d0..eda2c742ab4b5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumeric.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumeric.scala
@@ -143,7 +143,8 @@ case class HistogramNumeric(
if (buffer.getUsedBins < 1) {
null
} else {
- val result = (0 until buffer.getUsedBins).map { index =>
+ val array = new Array[AnyRef](buffer.getUsedBins)
+ (0 until buffer.getUsedBins).foreach { index =>
// Note that the 'coord.x' and 'coord.y' have double-precision floating point type here.
val coord = buffer.getBin(index)
if (propagateInputType) {
@@ -163,16 +164,16 @@ case class HistogramNumeric(
coord.x.toLong
case _ => coord.x
}
- InternalRow.apply(result, coord.y)
+ array(index) = InternalRow.apply(result, coord.y)
} else {
// Otherwise, just apply the double-precision values in 'coord.x' and 'coord.y' to the
// output row directly. In this case: 'SELECT histogram_numeric(val, 3)
// FROM VALUES (0L), (1L), (2L), (10L) AS tab(col)' returns an array of structs where the
// first field has DoubleType.
- InternalRow.apply(coord.x, coord.y)
+ array(index) = InternalRow.apply(coord.x, coord.y)
}
}
- new GenericArrayData(result)
+ new GenericArrayData(array)
}
}
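Why the rewrite helps, in a small standalone sketch (using `Array[Any]` for brevity; the patch itself uses `Array[AnyRef]`, which also compiles against `GenericArrayData`):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.GenericArrayData

// Pre-allocating the result array and filling it by index avoids materializing an
// intermediate IndexedSeq from `map` and hands GenericArrayData an array it can
// wrap directly.
val bins = new Array[Any](3)
(0 until 3).foreach { i => bins(i) = InternalRow(i.toDouble, 1.0d) }
val arrayData = new GenericArrayData(bins)
assert(arrayData.numElements() == 3)
```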
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index e77622a26d90a..c593c8bfb8341 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -170,7 +170,7 @@ case class CollectSet(
override def eval(buffer: mutable.HashSet[Any]): Any = {
val array = child.dataType match {
case BinaryType =>
- buffer.iterator.map(_.asInstanceOf[ArrayData].toByteArray()).toArray
+ buffer.iterator.map(_.asInstanceOf[ArrayData].toByteArray()).toArray[Any]
case _ => buffer.toArray
}
new GenericArrayData(array)
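The `[Any]` ascription matters because Scala arrays are invariant; a small illustration of the same conversion outside of the expression:

```scala
// An Array[Array[Byte]] is not an Array[Any], so the explicit element type keeps
// both branches of the match at Array[Any], which is what GenericArrayData's
// primary constructor expects.
val binaries: Iterator[Any] = Iterator(Array[Byte](1, 2), Array[Byte](3))
val asAny: Array[Any] = binaries.map(_.asInstanceOf[Array[Byte]]).toArray[Any]
assert(asAny.length == 2)
```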
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 375a2bde59230..5cdd3c7eb62d1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -433,7 +433,7 @@ case class ArraysZip(children: Seq[Expression], names: Seq[Expression])
inputArrays.map(_.numElements()).max
}
- val result = new Array[InternalRow](biggestCardinality)
+ val result = new Array[AnyRef](biggestCardinality)
val zippedArrs: Seq[(ArrayData, Int)] = inputArrays.zipWithIndex
for (i <- 0 until biggestCardinality) {
@@ -1058,20 +1058,26 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
override def checkInputDataTypes(): TypeCheckResult = base.dataType match {
case ArrayType(dt, _) if RowOrdering.isOrderable(dt) =>
- ascendingOrder match {
- case Literal(_: Boolean, BooleanType) =>
- TypeCheckResult.TypeCheckSuccess
- case _ =>
- DataTypeMismatch(
- errorSubClass = "UNEXPECTED_INPUT_TYPE",
- messageParameters = Map(
- "paramIndex" -> ordinalNumber(1),
- "requiredType" -> toSQLType(BooleanType),
- "inputSql" -> toSQLExpr(ascendingOrder),
- "inputType" -> toSQLType(ascendingOrder.dataType))
- )
+ if (!ascendingOrder.foldable) {
+ DataTypeMismatch(
+ errorSubClass = "NON_FOLDABLE_INPUT",
+ messageParameters = Map(
+ "inputName" -> toSQLId("ascendingOrder"),
+ "inputType" -> toSQLType(ascendingOrder.dataType),
+ "inputExpr" -> toSQLExpr(ascendingOrder)))
+ } else if (ascendingOrder.dataType != BooleanType) {
+ DataTypeMismatch(
+ errorSubClass = "UNEXPECTED_INPUT_TYPE",
+ messageParameters = Map(
+ "paramIndex" -> ordinalNumber(1),
+ "requiredType" -> toSQLType(BooleanType),
+ "inputSql" -> toSQLExpr(ascendingOrder),
+ "inputType" -> toSQLType(ascendingOrder.dataType))
+ )
+ } else {
+ TypeCheckResult.TypeCheckSuccess
}
- case ArrayType(dt, _) =>
+ case ArrayType(_, _) =>
DataTypeMismatch(
errorSubClass = "INVALID_ORDERING_TYPE",
messageParameters = Map(
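A hedged illustration of the new behavior for `sort_array`, using the error subclasses named in the patch (the failing query is left commented out because it raises at analysis time):

```scala
// Illustrative only, assuming a local SparkSession.
import org.apache.spark.sql.SparkSession
val spark = SparkSession.builder().master("local[1]").getOrCreate()

// A foldable BOOLEAN second argument now passes the check even if it is not a
// bare literal.
spark.sql("SELECT sort_array(array(3, 1, 2), 1 < 2)").show()

// A non-foldable second argument should now report
// DATATYPE_MISMATCH.NON_FOLDABLE_INPUT instead of UNEXPECTED_INPUT_TYPE:
// spark.sql("SELECT sort_array(array(3, 1, 2), id > 0) FROM range(1)")
```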
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index ba1beab28d9a7..b8b47f2763f5b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.types.StringTypeAnyCollation
+import org.apache.spark.sql.internal.types.StringTypeNonCSAICollation
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.ArrayImplicits._
@@ -579,7 +579,7 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E
override def third: Expression = keyValueDelim
override def inputTypes: Seq[AbstractDataType] =
- Seq(StringTypeAnyCollation, StringTypeAnyCollation, StringTypeAnyCollation)
+ Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation, StringTypeNonCSAICollation)
override def dataType: DataType = MapType(first.dataType, first.dataType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index ca74fefb9c032..56ecbf550e45c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistryBase, TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
+import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.TernaryLike
@@ -188,6 +189,13 @@ case class CaseWhen(
}
override def nullable: Boolean = {
+ if (branches.exists(_._1 == TrueLiteral)) {
+ // If any branch condition is always true, nullability only depends on the
+ // branches before the first TrueLiteral and on the value of that TrueLiteral
+ // branch, because later branches and the else value are unreachable.
+ val (h, t) = branches.span(_._1 != TrueLiteral)
+ return h.exists(_._2.nullable) || t.head._2.nullable
+ }
// Result is nullable if any of the branch is nullable, or if the else value is nullable
branches.exists(_._2.nullable) || elseValue.map(_.nullable).getOrElse(true)
}
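A small sketch of the shortcut, built directly on the Catalyst API used above:

```scala
import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Literal}
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.types.StringType

// CASE WHEN false THEN 'a' WHEN true THEN 'b' ELSE CAST(NULL AS STRING) END
// The always-true second branch returns a non-nullable literal, and the
// unreachable ELSE no longer forces nullable = true.
val cw = CaseWhen(
  Seq(FalseLiteral -> Literal("a"), TrueLiteral -> Literal("b")),
  Some(Literal.create(null, StringType)))
assert(!cw.nullable)
```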
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 2cc88a25f465d..dc58352a1b362 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import scala.collection.mutable
+import scala.jdk.CollectionConverters.CollectionHasAsScala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
@@ -28,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.logical.{FunctionSignature, InputParameter}
import org.apache.spark.sql.catalyst.trees.TreePattern.{GENERATOR, TreePattern}
-import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
+import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, MapData}
import org.apache.spark.sql.catalyst.util.SQLKeywordUtils._
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
@@ -618,3 +619,44 @@ case class SQLKeywords() extends LeafExpression with Generator with CodegenFallb
override def prettyName: String = "sql_keywords"
}
+
+@ExpressionDescription(
+ usage = """_FUNC_() - Get all of the Spark SQL string collations""",
+ examples = """
+ Examples:
+ > SELECT * FROM _FUNC_() WHERE NAME = 'UTF8_BINARY';
+ SYSTEM BUILTIN UTF8_BINARY NULL NULL ACCENT_SENSITIVE CASE_SENSITIVE NO_PAD NULL
+ """,
+ since = "4.0.0",
+ group = "generator_funcs")
+case class Collations() extends LeafExpression with Generator with CodegenFallback {
+ override def elementSchema: StructType = new StructType()
+ .add("CATALOG", StringType, nullable = false)
+ .add("SCHEMA", StringType, nullable = false)
+ .add("NAME", StringType, nullable = false)
+ .add("LANGUAGE", StringType)
+ .add("COUNTRY", StringType)
+ .add("ACCENT_SENSITIVITY", StringType, nullable = false)
+ .add("CASE_SENSITIVITY", StringType, nullable = false)
+ .add("PAD_ATTRIBUTE", StringType, nullable = false)
+ .add("ICU_VERSION", StringType)
+
+ override def eval(input: InternalRow): IterableOnce[InternalRow] = {
+ CollationFactory.listCollations().asScala.map(CollationFactory.loadCollationMeta).map { m =>
+ InternalRow(
+ UTF8String.fromString(m.catalog),
+ UTF8String.fromString(m.schema),
+ UTF8String.fromString(m.collationName),
+ UTF8String.fromString(m.language),
+ UTF8String.fromString(m.country),
+ UTF8String.fromString(
+ if (m.accentSensitivity) "ACCENT_SENSITIVE" else "ACCENT_INSENSITIVE"),
+ UTF8String.fromString(
+ if (m.caseSensitivity) "CASE_SENSITIVE" else "CASE_INSENSITIVE"),
+ UTF8String.fromString(m.padAttribute),
+ UTF8String.fromString(m.icuVersion))
+ }
+ }
+
+ override def prettyName: String = "collations"
+}
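Assuming the generator is registered under the name shown in its `_FUNC_` example, usage looks like:

```scala
// Illustrative only, assuming a local Spark 4.0 SparkSession.
import org.apache.spark.sql.SparkSession
val spark = SparkSession.builder().master("local[1]").getOrCreate()

// The generator is used as a table-valued function; each row describes one
// collation: catalog, schema, name, locale info, sensitivity flags, pad
// attribute, and the ICU version (NULL for the non-ICU UTF8 collations).
spark.sql(
  "SELECT NAME, ACCENT_SENSITIVITY, CASE_SENSITIVITY FROM collations() WHERE NAME LIKE 'UTF8%'"
).show(truncate = false)
```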
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index 7005d663a3f96..2037eb22fede6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -632,7 +632,8 @@ case class JsonToStructs(
schema: DataType,
options: Map[String, String],
child: Expression,
- timeZoneId: Option[String] = None)
+ timeZoneId: Option[String] = None,
+ variantAllowDuplicateKeys: Boolean = SQLConf.get.getConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS))
extends UnaryExpression
with TimeZoneAwareExpression
with CodegenFallback
@@ -719,7 +720,8 @@ case class JsonToStructs(
override def nullSafeEval(json: Any): Any = nullableSchema match {
case _: VariantType =>
- VariantExpressionEvalUtils.parseJson(json.asInstanceOf[UTF8String])
+ VariantExpressionEvalUtils.parseJson(json.asInstanceOf[UTF8String],
+ allowDuplicateKeys = variantAllowDuplicateKeys)
case _ =>
converter(parser.parse(json.asInstanceOf[UTF8String]))
}
@@ -737,6 +739,12 @@ case class JsonToStructs(
copy(child = newChild)
}
+object JsonToStructs {
+ def unapply(
+ j: JsonToStructs): Option[(DataType, Map[String, String], Expression, Option[String])] =
+ Some((j.schema, j.options, j.child, j.timeZoneId))
+}
+
/**
* Converts a [[StructType]], [[ArrayType]] or [[MapType]] to a JSON output string.
*/
@@ -1072,7 +1080,7 @@ case class JsonObjectKeys(child: Expression) extends UnaryExpression with Codege
// skip all the children of inner object or array
parser.skipChildren()
}
- new GenericArrayData(arrayBufferOfKeys.toArray)
+ new GenericArrayData(arrayBufferOfKeys.toArray[Any])
}
override protected def withNewChildInternal(newChild: Expression): JsonObjectKeys =
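A sketch of the flag the new `variantAllowDuplicateKeys` field controls, calling the same `parseJson` overload the patch uses (exact error class aside, duplicate keys are rejected when the flag is off):

```scala
import scala.util.Try
import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils
import org.apache.spark.unsafe.types.UTF8String

val dup = UTF8String.fromString("""{"a": 1, "a": 2}""")

// With duplicate keys allowed, the input parses into a variant object keeping one
// value per key ...
val parsed = VariantExpressionEvalUtils.parseJson(dup, allowDuplicateKeys = true)

// ... whereas allowDuplicateKeys = false makes the same input fail with a
// duplicate-key error. JsonToStructs now threads the captured SQLConf value into
// this flag instead of relying on the parser's default.
val rejected = Try(VariantExpressionEvalUtils.parseJson(dup, allowDuplicateKeys = false))
assert(rejected.isFailure)
```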
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index 00274a16b888b..ddba820414ae4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -1293,7 +1293,7 @@ sealed trait BitShiftOperation
* @param right number of bits to left shift.
*/
@ExpressionDescription(
- usage = "base << exp - Bitwise left shift.",
+ usage = "base _FUNC_ exp - Bitwise left shift.",
examples = """
Examples:
> SELECT shiftleft(2, 1);
@@ -1322,7 +1322,7 @@ case class ShiftLeft(left: Expression, right: Expression) extends BitShiftOperat
* @param right number of bits to right shift.
*/
@ExpressionDescription(
- usage = "base >> expr - Bitwise (signed) right shift.",
+ usage = "base _FUNC_ expr - Bitwise (signed) right shift.",
examples = """
Examples:
> SELECT shiftright(4, 1);
@@ -1350,7 +1350,7 @@ case class ShiftRight(left: Expression, right: Expression) extends BitShiftOpera
* @param right the number of bits to right shift.
*/
@ExpressionDescription(
- usage = "base >>> expr - Bitwise unsigned right shift.",
+ usage = "base _FUNC_ expr - Bitwise unsigned right shift.",
examples = """
Examples:
> SELECT shiftrightunsigned(4, 1);
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index f5db972a28643..f329f8346b0de 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -17,13 +17,18 @@
package org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.UnresolvedSeed
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedSeed}
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
+import org.apache.spark.sql.catalyst.expressions.ExpectsInputTypes.{ordinalNumber, toSQLExpr, toSQLType}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.trees.TreePattern.{EXPRESSION_WITH_RANDOM_SEED, TreePattern}
+import org.apache.spark.sql.catalyst.trees.{BinaryLike, TernaryLike, UnaryLike}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{EXPRESSION_WITH_RANDOM_SEED, RUNTIME_REPLACEABLE, TreePattern}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.random.XORShiftRandom
/**
@@ -33,8 +38,7 @@ import org.apache.spark.util.random.XORShiftRandom
*
* Since this expression is stateful, it cannot be a case object.
*/
-abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterministic
- with ExpressionWithRandomSeed {
+trait RDG extends Expression with ExpressionWithRandomSeed {
/**
* Record ID within each partition. By being transient, the Random Number Generator is
* reset every time we serialize and deserialize and initialize it.
@@ -43,12 +47,6 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterm
override def stateful: Boolean = true
- override protected def initializeInternal(partitionIndex: Int): Unit = {
- rng = new XORShiftRandom(seed + partitionIndex)
- }
-
- override def seedExpression: Expression = child
-
@transient protected lazy val seed: Long = seedExpression match {
case e if e.dataType == IntegerType => e.eval().asInstanceOf[Int]
case e if e.dataType == LongType => e.eval().asInstanceOf[Long]
@@ -57,6 +55,15 @@ abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterm
override def nullable: Boolean = false
override def dataType: DataType = DoubleType
+}
+
+abstract class NondeterministicUnaryRDG
+ extends RDG with UnaryLike[Expression] with Nondeterministic with ExpectsInputTypes {
+ override def seedExpression: Expression = child
+
+ override protected def initializeInternal(partitionIndex: Int): Unit = {
+ rng = new XORShiftRandom(seed + partitionIndex)
+ }
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType))
}
@@ -74,6 +81,7 @@ trait ExpressionWithRandomSeed extends Expression {
private[catalyst] object ExpressionWithRandomSeed {
def expressionToSeed(e: Expression, source: String): Option[Long] = e match {
+ case IntegerLiteral(seed) => Some(seed)
case LongLiteral(seed) => Some(seed)
case Literal(null, _) => None
case _ => throw QueryCompilationErrors.invalidRandomSeedParameter(source, e)
@@ -99,7 +107,7 @@ private[catalyst] object ExpressionWithRandomSeed {
since = "1.5.0",
group = "math_funcs")
// scalastyle:on line.size.limit
-case class Rand(child: Expression, hideSeed: Boolean = false) extends RDG {
+case class Rand(child: Expression, hideSeed: Boolean = false) extends NondeterministicUnaryRDG {
def this() = this(UnresolvedSeed, true)
@@ -150,7 +158,7 @@ object Rand {
since = "1.5.0",
group = "math_funcs")
// scalastyle:on line.size.limit
-case class Randn(child: Expression, hideSeed: Boolean = false) extends RDG {
+case class Randn(child: Expression, hideSeed: Boolean = false) extends NondeterministicUnaryRDG {
def this() = this(UnresolvedSeed, true)
@@ -181,3 +189,236 @@ case class Randn(child: Expression, hideSeed: Boolean = false) extends RDG {
object Randn {
def apply(seed: Long): Randn = Randn(Literal(seed, LongType))
}
+
+@ExpressionDescription(
+ usage = """
+ _FUNC_(min, max[, seed]) - Returns a random value with independent and identically
+ distributed (i.i.d.) values with the specified range of numbers. The random seed is optional.
+ The provided numbers specifying the minimum and maximum values of the range must be constant.
+ If both of these numbers are integers, then the result will also be an integer. Otherwise if
+ one or both of these are floating-point numbers, then the result will also be a floating-point
+ number.
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_(10, 20, 0) > 0 AS result;
+ true
+ """,
+ since = "4.0.0",
+ group = "math_funcs")
+case class Uniform(min: Expression, max: Expression, seedExpression: Expression)
+ extends RuntimeReplaceable with TernaryLike[Expression] with RDG {
+ def this(min: Expression, max: Expression) = this(min, max, UnresolvedSeed)
+
+ final override lazy val deterministic: Boolean = false
+ override val nodePatterns: Seq[TreePattern] =
+ Seq(RUNTIME_REPLACEABLE, EXPRESSION_WITH_RANDOM_SEED)
+
+ override val dataType: DataType = {
+ val first = min.dataType
+ val second = max.dataType
+ (min.dataType, max.dataType) match {
+ case _ if !seedExpression.resolved || seedExpression.dataType == NullType =>
+ NullType
+ case (_, NullType) | (NullType, _) => NullType
+ case (_, LongType) | (LongType, _)
+ if Seq(first, second).forall(integer) => LongType
+ case (_, IntegerType) | (IntegerType, _)
+ if Seq(first, second).forall(integer) => IntegerType
+ case (_, ShortType) | (ShortType, _)
+ if Seq(first, second).forall(integer) => ShortType
+ case (_, DoubleType) | (DoubleType, _) => DoubleType
+ case (_, FloatType) | (FloatType, _) => FloatType
+ case _ =>
+ throw SparkException.internalError(
+ s"Unexpected argument data types: ${min.dataType}, ${max.dataType}")
+ }
+ }
+
+ private def integer(t: DataType): Boolean = t match {
+ case _: ShortType | _: IntegerType | _: LongType => true
+ case _ => false
+ }
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ var result: TypeCheckResult = TypeCheckResult.TypeCheckSuccess
+ def requiredType = "integer or floating-point"
+ Seq((min, "min", 0),
+ (max, "max", 1),
+ (seedExpression, "seed", 2)).foreach {
+ case (expr: Expression, name: String, index: Int) =>
+ if (result == TypeCheckResult.TypeCheckSuccess) {
+ if (!expr.foldable) {
+ result = DataTypeMismatch(
+ errorSubClass = "NON_FOLDABLE_INPUT",
+ messageParameters = Map(
+ "inputName" -> name,
+ "inputType" -> requiredType,
+ "inputExpr" -> toSQLExpr(expr)))
+ } else expr.dataType match {
+ case _: ShortType | _: IntegerType | _: LongType | _: FloatType | _: DoubleType |
+ _: NullType =>
+ case _ =>
+ result = DataTypeMismatch(
+ errorSubClass = "UNEXPECTED_INPUT_TYPE",
+ messageParameters = Map(
+ "paramIndex" -> ordinalNumber(index),
+ "requiredType" -> requiredType,
+ "inputSql" -> toSQLExpr(expr),
+ "inputType" -> toSQLType(expr.dataType)))
+ }
+ }
+ }
+ result
+ }
+
+ override def first: Expression = min
+ override def second: Expression = max
+ override def third: Expression = seedExpression
+
+ override def withNewSeed(newSeed: Long): Expression =
+ Uniform(min, max, Literal(newSeed, LongType))
+
+ override def withNewChildrenInternal(
+ newFirst: Expression, newSecond: Expression, newThird: Expression): Expression =
+ Uniform(newFirst, newSecond, newThird)
+
+ override def replacement: Expression = {
+ if (Seq(min, max, seedExpression).exists(_.dataType == NullType)) {
+ Literal(null)
+ } else {
+ def cast(e: Expression, to: DataType): Expression = if (e.dataType == to) e else Cast(e, to)
+ cast(Add(
+ cast(min, DoubleType),
+ Multiply(
+ Subtract(
+ cast(max, DoubleType),
+ cast(min, DoubleType)),
+ Rand(seed))),
+ dataType)
+ }
+ }
+}
+
+@ExpressionDescription(
+ usage = """
+ _FUNC_(length[, seed]) - Returns a string of the specified length whose characters are chosen
+ uniformly at random from the following pool of characters: 0-9, a-z, A-Z. The random seed is
+ optional. The string length must be a constant two-byte or four-byte integer (SMALLINT or INT,
+ respectively).
+ """,
+ examples =
+ """
+ Examples:
+ > SELECT _FUNC_(3, 0) AS result;
+ ceV
+ """,
+ since = "4.0.0",
+ group = "string_funcs")
+case class RandStr(length: Expression, override val seedExpression: Expression)
+ extends ExpressionWithRandomSeed with BinaryLike[Expression] with Nondeterministic {
+ def this(length: Expression) = this(length, UnresolvedSeed)
+
+ override def nullable: Boolean = false
+ override def dataType: DataType = StringType
+ override def stateful: Boolean = true
+ override def left: Expression = length
+ override def right: Expression = seedExpression
+
+ /**
+ * Record ID within each partition. By being transient, the Random Number Generator is
+ * reset every time we serialize and deserialize and initialize it.
+ */
+ @transient protected var rng: XORShiftRandom = _
+
+ @transient protected lazy val seed: Long = seedExpression match {
+ case e if e.dataType == IntegerType => e.eval().asInstanceOf[Int]
+ case e if e.dataType == LongType => e.eval().asInstanceOf[Long]
+ }
+ override protected def initializeInternal(partitionIndex: Int): Unit = {
+ rng = new XORShiftRandom(seed + partitionIndex)
+ }
+
+ override def withNewSeed(newSeed: Long): Expression = RandStr(length, Literal(newSeed, LongType))
+ override def withNewChildrenInternal(newFirst: Expression, newSecond: Expression): Expression =
+ RandStr(newFirst, newSecond)
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ var result: TypeCheckResult = TypeCheckResult.TypeCheckSuccess
+ def requiredType = "INT or SMALLINT"
+ Seq((length, "length", 0),
+ (seedExpression, "seedExpression", 1)).foreach {
+ case (expr: Expression, name: String, index: Int) =>
+ if (result == TypeCheckResult.TypeCheckSuccess) {
+ if (!expr.foldable) {
+ result = DataTypeMismatch(
+ errorSubClass = "NON_FOLDABLE_INPUT",
+ messageParameters = Map(
+ "inputName" -> name,
+ "inputType" -> requiredType,
+ "inputExpr" -> toSQLExpr(expr)))
+ } else expr.dataType match {
+ case _: ShortType | _: IntegerType =>
+ case _: LongType if index == 1 =>
+ case _ =>
+ result = DataTypeMismatch(
+ errorSubClass = "UNEXPECTED_INPUT_TYPE",
+ messageParameters = Map(
+ "paramIndex" -> ordinalNumber(index),
+ "requiredType" -> requiredType,
+ "inputSql" -> toSQLExpr(expr),
+ "inputType" -> toSQLType(expr.dataType)))
+ }
+ }
+ }
+ result
+ }
+
+ override def evalInternal(input: InternalRow): Any = {
+ val numChars = length.eval(input).asInstanceOf[Number].intValue()
+ val bytes = new Array[Byte](numChars)
+ (0 until numChars).foreach { i =>
+ // Generate a random number between 0 and 61, inclusive. Among those 62 choices we
+ // map to 0-9, a-z, or A-Z, which contribute 10, 26, and 26 characters respectively
+ // (10 + 26 + 26 = 62).
+ val num = (rng.nextInt() % 62).abs
+ num match {
+ case _ if num < 10 =>
+ bytes.update(i, ('0' + num).toByte)
+ case _ if num < 36 =>
+ bytes.update(i, ('a' + num - 10).toByte)
+ case _ =>
+ bytes.update(i, ('A' + num - 36).toByte)
+ }
+ }
+ val result: UTF8String = UTF8String.fromBytes(bytes.toArray)
+ result
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val className = classOf[XORShiftRandom].getName
+ val rngTerm = ctx.addMutableState(className, "rng")
+ ctx.addPartitionInitializationStatement(
+ s"$rngTerm = new $className(${seed}L + partitionIndex);")
+ val eval = length.genCode(ctx)
+ ev.copy(code =
+ code"""
+ |${eval.code}
+ |int length = (int)(${eval.value});
+ |char[] chars = new char[length];
+ |for (int i = 0; i < length; i++) {
+ | int v = Math.abs($rngTerm.nextInt() % 62);
+ | if (v < 10) {
+ | chars[i] = (char)('0' + v);
+ | } else if (v < 36) {
+ | chars[i] = (char)('a' + (v - 10));
+ | } else {
+ | chars[i] = (char)('A' + (v - 36));
+ | }
+ |}
+ |UTF8String ${ev.value} = UTF8String.fromString(new String(chars));
+ |boolean ${ev.isNull} = false;
+ |""".stripMargin,
+ isNull = FalseLiteral)
+ }
+}
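A hedged usage sketch, assuming `uniform` and `randstr` are registered under the names used in their `_FUNC_` examples:

```scala
// Illustrative only, assuming a local Spark 4.0 SparkSession.
import org.apache.spark.sql.SparkSession
val spark = SparkSession.builder().master("local[1]").getOrCreate()

// uniform(min, max[, seed]) is rewritten at analysis time to
// CAST(min + (max - min) * rand(seed) AS <result type>), so integer bounds yield an
// integral result and floating-point bounds yield a floating-point one.
spark.sql("SELECT uniform(10, 20, 0) AS i, uniform(0.0, 1.0, 0) AS d").show()

// randstr(length[, seed]) draws each character uniformly from 0-9, a-z, A-Z.
spark.sql("SELECT randstr(8, 0) AS s").show()
```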
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 6ccd5a451eafc..da6d786efb4e3 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import java.nio.{ByteBuffer, CharBuffer}
import java.nio.charset.CharacterCodingException
-import java.text.{BreakIterator, DecimalFormat, DecimalFormatSymbols}
+import java.text.{DecimalFormat, DecimalFormatSymbols}
import java.util.{Base64 => JBase64, HashMap, Locale, Map => JMap}
import scala.collection.mutable.ArrayBuffer
@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LO
import org.apache.spark.sql.catalyst.util.{ArrayData, CharsetProvider, CollationFactory, CollationSupport, GenericArrayData, TypeUtils}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation}
+import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation, StringTypeNonCSAICollation}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.UTF8StringBuilder
import org.apache.spark.unsafe.array.ByteArrayMethods
@@ -609,6 +609,8 @@ case class Contains(left: Expression, right: Expression) extends StringPredicate
defineCodeGen(ctx, ev, (c1, c2) =>
CollationSupport.Contains.genCode(c1, c2, collationId))
}
+ override def inputTypes : Seq[AbstractDataType] =
+ Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation)
override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): Contains = copy(left = newLeft, right = newRight)
}
@@ -650,6 +652,10 @@ case class StartsWith(left: Expression, right: Expression) extends StringPredica
defineCodeGen(ctx, ev, (c1, c2) =>
CollationSupport.StartsWith.genCode(c1, c2, collationId))
}
+
+ override def inputTypes : Seq[AbstractDataType] =
+ Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation)
+
override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): StartsWith = copy(left = newLeft, right = newRight)
}
@@ -691,6 +697,10 @@ case class EndsWith(left: Expression, right: Expression) extends StringPredicate
defineCodeGen(ctx, ev, (c1, c2) =>
CollationSupport.EndsWith.genCode(c1, c2, collationId))
}
+
+ override def inputTypes : Seq[AbstractDataType] =
+ Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation)
+
override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): EndsWith = copy(left = newLeft, right = newRight)
}
@@ -919,7 +929,7 @@ case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExp
override def dataType: DataType = srcExpr.dataType
override def inputTypes: Seq[AbstractDataType] =
- Seq(StringTypeAnyCollation, StringTypeAnyCollation, StringTypeAnyCollation)
+ Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation, StringTypeNonCSAICollation)
override def first: Expression = srcExpr
override def second: Expression = searchExpr
override def third: Expression = replaceExpr
@@ -1167,7 +1177,7 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac
override def dataType: DataType = srcExpr.dataType
override def inputTypes: Seq[AbstractDataType] =
- Seq(StringTypeAnyCollation, StringTypeAnyCollation, StringTypeAnyCollation)
+ Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation, StringTypeNonCSAICollation)
override def first: Expression = srcExpr
override def second: Expression = matchingExpr
override def third: Expression = replaceExpr
@@ -1394,6 +1404,9 @@ case class StringTrim(srcStr: Expression, trimStr: Option[Expression] = None)
override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String =
CollationSupport.StringTrim.exec(srcString, trimString, collationId)
+ override def inputTypes: Seq[AbstractDataType] =
+ Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation)
+
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
copy(
srcStr = newChildren.head,
@@ -1501,6 +1514,9 @@ case class StringTrimLeft(srcStr: Expression, trimStr: Option[Expression] = None
override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String =
CollationSupport.StringTrimLeft.exec(srcString, trimString, collationId)
+ override def inputTypes: Seq[AbstractDataType] =
+ Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation)
+
override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): StringTrimLeft =
copy(
@@ -1561,6 +1577,9 @@ case class StringTrimRight(srcStr: Expression, trimStr: Option[Expression] = Non
override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String =
CollationSupport.StringTrimRight.exec(srcString, trimString, collationId)
+ override def inputTypes: Seq[AbstractDataType] =
+ Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation)
+
override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): StringTrimRight =
copy(
@@ -1595,7 +1614,7 @@ case class StringInstr(str: Expression, substr: Expression)
override def right: Expression = substr
override def dataType: DataType = IntegerType
override def inputTypes: Seq[AbstractDataType] =
- Seq(StringTypeAnyCollation, StringTypeAnyCollation)
+ Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation)
override def nullSafeEval(string: Any, sub: Any): Any = {
CollationSupport.StringInstr.
@@ -1643,7 +1662,7 @@ case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr:
override def dataType: DataType = strExpr.dataType
override def inputTypes: Seq[AbstractDataType] =
- Seq(StringTypeAnyCollation, StringTypeAnyCollation, IntegerType)
+ Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation, IntegerType)
override def first: Expression = strExpr
override def second: Expression = delimExpr
override def third: Expression = countExpr
@@ -1701,7 +1720,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression)
override def nullable: Boolean = substr.nullable || str.nullable
override def dataType: DataType = IntegerType
override def inputTypes: Seq[AbstractDataType] =
- Seq(StringTypeAnyCollation, StringTypeAnyCollation, IntegerType)
+ Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation, IntegerType)
override def eval(input: InternalRow): Any = {
val s = start.eval(input)
@@ -1969,13 +1988,10 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
if (pattern == null) {
null
} else {
- val sb = new StringBuffer()
- val formatter = new java.util.Formatter(sb, Locale.US)
-
+ val formatter = new java.util.Formatter(Locale.US)
val arglist = children.tail.map(_.eval(input).asInstanceOf[AnyRef])
- formatter.format(pattern.asInstanceOf[UTF8String].toString, arglist: _*)
-
- UTF8String.fromString(sb.toString)
+ UTF8String.fromString(
+ formatter.format(pattern.asInstanceOf[UTF8String].toString, arglist: _*).toString)
}
}
@@ -2006,19 +2022,16 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
val form = ctx.freshName("formatter")
val formatter = classOf[java.util.Formatter].getName
- val sb = ctx.freshName("sb")
- val stringBuffer = classOf[StringBuffer].getName
ev.copy(code = code"""
${pattern.code}
boolean ${ev.isNull} = ${pattern.isNull};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) {
- $stringBuffer $sb = new $stringBuffer();
- $formatter $form = new $formatter($sb, ${classOf[Locale].getName}.US);
+ $formatter $form = new $formatter(${classOf[Locale].getName}.US);
Object[] $argList = new Object[$numArgLists];
$argListCodes
- $form.format(${pattern.value}.toString(), $argList);
- ${ev.value} = UTF8String.fromString($sb.toString());
+ ${ev.value} = UTF8String.fromString(
+ $form.format(${pattern.value}.toString(), $argList).toString());
}""")
}
@@ -3333,14 +3346,37 @@ case class FormatNumber(x: Expression, d: Expression)
/**
* Splits a string into arrays of sentences, where each sentence is an array of words.
- * The 'lang' and 'country' arguments are optional, and if omitted, the default locale is used.
+ * The `lang` and `country` arguments are optional; both default to ''.
+ * - When they are omitted:
+ * 1. If they are both omitted, the `Locale.ROOT - locale(language='', country='')` is used.
+ * The `Locale.ROOT` is regarded as the base locale of all locales, and is used as the
+ * language/country neutral locale for the locale sensitive operations.
+ * 2. If the `country` is omitted, the `locale(language, country='')` is used.
+ * - When they are null:
+ * 1. If they are both `null`, the `Locale.US - locale(language='en', country='US')` is used.
+ * 2. If the `language` is null and the `country` is not null,
+ * the `Locale.US - locale(language='en', country='US')` is used.
+ * 3. If the `language` is not null and the `country` is null, the `locale(language)` is used.
+ * 4. If neither is `null`, the `locale(language, country)` is used.
*/
@ExpressionDescription(
- usage = "_FUNC_(str[, lang, country]) - Splits `str` into an array of array of words.",
+ usage = "_FUNC_(str[, lang[, country]]) - Splits `str` into an array of array of words.",
+ arguments = """
+ Arguments:
+ * str - A STRING expression to be parsed.
+ * lang - An optional STRING expression with a language code from ISO 639 Alpha-2 (e.g. 'DE'),
+ Alpha-3, or a language subtag of up to 8 characters.
+ * country - An optional STRING expression with a country code from ISO 3166 alpha-2 country
+ code or a UN M.49 numeric-3 area code.
+ """,
examples = """
Examples:
> SELECT _FUNC_('Hi there! Good morning.');
[["Hi","there"],["Good","morning"]]
+ > SELECT _FUNC_('Hi there! Good morning.', 'en');
+ [["Hi","there"],["Good","morning"]]
+ > SELECT _FUNC_('Hi there! Good morning.', 'en', 'US');
+ [["Hi","there"],["Good","morning"]]
""",
since = "2.0.0",
group = "string_funcs")
@@ -3348,7 +3384,9 @@ case class Sentences(
str: Expression,
language: Expression = Literal(""),
country: Expression = Literal(""))
- extends TernaryExpression with ImplicitCastInputTypes with CodegenFallback {
+ extends TernaryExpression
+ with ImplicitCastInputTypes
+ with RuntimeReplaceable {
def this(str: Expression) = this(str, Literal(""), Literal(""))
def this(str: Expression, language: Expression) = this(str, language, Literal(""))
@@ -3362,49 +3400,18 @@ case class Sentences(
override def second: Expression = language
override def third: Expression = country
- override def eval(input: InternalRow): Any = {
- val string = str.eval(input)
- if (string == null) {
- null
- } else {
- val languageStr = language.eval(input).asInstanceOf[UTF8String]
- val countryStr = country.eval(input).asInstanceOf[UTF8String]
- val locale = if (languageStr != null && countryStr != null) {
- new Locale(languageStr.toString, countryStr.toString)
- } else {
- Locale.US
- }
- getSentences(string.asInstanceOf[UTF8String].toString, locale)
- }
- }
-
- private def getSentences(sentences: String, locale: Locale) = {
- val bi = BreakIterator.getSentenceInstance(locale)
- bi.setText(sentences)
- var idx = 0
- val result = new ArrayBuffer[GenericArrayData]
- while (bi.next != BreakIterator.DONE) {
- val sentence = sentences.substring(idx, bi.current)
- idx = bi.current
-
- val wi = BreakIterator.getWordInstance(locale)
- var widx = 0
- wi.setText(sentence)
- val words = new ArrayBuffer[UTF8String]
- while (wi.next != BreakIterator.DONE) {
- val word = sentence.substring(widx, wi.current)
- widx = wi.current
- if (Character.isLetterOrDigit(word.charAt(0))) words += UTF8String.fromString(word)
- }
- result += new GenericArrayData(words)
- }
- new GenericArrayData(result)
- }
+ override def replacement: Expression =
+ StaticInvoke(
+ classOf[ExpressionImplUtils],
+ dataType,
+ "getSentences",
+ Seq(str, language, country),
+ inputTypes,
+ propagateNull = false)
override protected def withNewChildrenInternal(
newFirst: Expression, newSecond: Expression, newThird: Expression): Sentences =
copy(str = newFirst, language = newSecond, country = newThird)
-
}
/**
@@ -3475,7 +3482,7 @@ case class SplitPart (
false)
override def nodeName: String = "split_part"
override def inputTypes: Seq[AbstractDataType] =
- Seq(StringTypeAnyCollation, StringTypeAnyCollation, IntegerType)
+ Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation, IntegerType)
def children: Seq[Expression] = Seq(str, delimiter, partNum)
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = {
copy(str = newChildren.apply(0), delimiter = newChildren.apply(1),
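Behavior for `sentences` is intended to be unchanged for callers; a quick check against the docstring examples (illustrative only):

```scala
// Illustrative only, assuming a local SparkSession.
import org.apache.spark.sql.SparkSession
val spark = SparkSession.builder().master("local[1]").getOrCreate()

// Results match the examples above; the difference is that evaluation is now
// delegated to ExpressionImplUtils.getSentences through StaticInvoke rather than
// an eval override with inline BreakIterator logic.
spark.sql("SELECT sentences('Hi there! Good morning.')").show(truncate = false)
spark.sql("SELECT sentences('Hi there! Good morning.', 'en', 'US')").show(truncate = false)
```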
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromProtobufSqlFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromProtobufSqlFunctions.scala
index a34a2c740876e..ad9610ea0c78a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromProtobufSqlFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromProtobufSqlFunctions.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.types.{MapType, NullType, StringType}
+import org.apache.spark.sql.types.{BinaryType, MapType, NullType, StringType}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
@@ -88,6 +88,27 @@ case class FromProtobuf(
messageName: Expression,
descFilePath: Expression,
options: Expression) extends QuaternaryExpression with RuntimeReplaceable {
+
+ def this(data: Expression, messageName: Expression, descFilePathOrOptions: Expression) = {
+ this(
+ data,
+ messageName,
+ descFilePathOrOptions match {
+ case lit: Literal
+ if lit.dataType == StringType || lit.dataType == BinaryType => descFilePathOrOptions
+ case _ => Literal(null)
+ },
+ descFilePathOrOptions.dataType match {
+ case _: MapType => descFilePathOrOptions
+ case _ => Literal(null)
+ }
+ )
+ }
+
+ def this(data: Expression, messageName: Expression) = {
+ this(data, messageName, Literal(null), Literal(null))
+ }
+
override def first: Expression = data
override def second: Expression = messageName
override def third: Expression = descFilePath
@@ -110,11 +131,11 @@ case class FromProtobuf(
"representing the Protobuf message name"))
}
val descFilePathCheck = descFilePath.dataType match {
- case _: StringType if descFilePath.foldable => None
+ case _: StringType | BinaryType | NullType if descFilePath.foldable => None
case _ =>
Some(TypeCheckResult.TypeCheckFailure(
"The third argument of the FROM_PROTOBUF SQL function must be a constant string " +
- "representing the Protobuf descriptor file path"))
+ "or binary data representing the Protobuf descriptor file path"))
}
val optionsCheck = options.dataType match {
case MapType(StringType, StringType, _) |
@@ -141,7 +162,10 @@ case class FromProtobuf(
s.toString
}
val descFilePathValue: Option[Array[Byte]] = descFilePath.eval() match {
+ case s: UTF8String if s.toString.isEmpty => None
case s: UTF8String => Some(ProtobufHelper.readDescriptorFileContent(s.toString))
+ case bytes: Array[Byte] if bytes.isEmpty => None
+ case bytes: Array[Byte] => Some(bytes)
case null => None
}
val optionsValue: Map[String, String] = options.eval() match {
@@ -201,6 +225,27 @@ case class ToProtobuf(
messageName: Expression,
descFilePath: Expression,
options: Expression) extends QuaternaryExpression with RuntimeReplaceable {
+
+ def this(data: Expression, messageName: Expression, descFilePathOrOptions: Expression) = {
+ this(
+ data,
+ messageName,
+ descFilePathOrOptions match {
+ case lit: Literal
+ if lit.dataType == StringType || lit.dataType == BinaryType => descFilePathOrOptions
+ case _ => Literal(null)
+ },
+ descFilePathOrOptions.dataType match {
+ case _: MapType => descFilePathOrOptions
+ case _ => Literal(null)
+ }
+ )
+ }
+
+ def this(data: Expression, messageName: Expression) = {
+ this(data, messageName, Literal(null), Literal(null))
+ }
+
override def first: Expression = data
override def second: Expression = messageName
override def third: Expression = descFilePath
@@ -223,11 +268,11 @@ case class ToProtobuf(
"representing the Protobuf message name"))
}
val descFilePathCheck = descFilePath.dataType match {
- case _: StringType if descFilePath.foldable => None
+ case _: StringType | BinaryType | NullType if descFilePath.foldable => None
case _ =>
Some(TypeCheckResult.TypeCheckFailure(
"The third argument of the TO_PROTOBUF SQL function must be a constant string " +
- "representing the Protobuf descriptor file path"))
+ "or binary data representing the Protobuf descriptor file path"))
}
val optionsCheck = options.dataType match {
case MapType(StringType, StringType, _) |
@@ -256,6 +301,7 @@ case class ToProtobuf(
}
val descFilePathValue: Option[Array[Byte]] = descFilePath.eval() match {
case s: UTF8String => Some(ProtobufHelper.readDescriptorFileContent(s.toString))
+ case bytes: Array[Byte] => Some(bytes)
case null => None
}
val optionsValue: Map[String, String] = options.eval() match {
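A hedged sketch of the new call forms; the table `events(payload BINARY)` and the message type `com.example.Event` are hypothetical:

```scala
// Illustrative only, assuming a local SparkSession and a resolvable message type.
import org.apache.spark.sql.SparkSession
val spark = SparkSession.builder().master("local[1]").getOrCreate()

// New two-argument form: no descriptor argument at all.
spark.sql("SELECT from_protobuf(payload, 'com.example.Event') FROM events")

// Three-argument form: the auxiliary constructors route the third argument either
// to the descriptor slot (STRING path or BINARY descriptor set) or to the options
// slot (MAP), matching the type checks above.
spark.sql(
  "SELECT from_protobuf(payload, 'com.example.Event', map('recursive.fields.max.depth', '2')) FROM events")
```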
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala
index bbf554d384b12..487985b4770ee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala
@@ -126,7 +126,7 @@ object VariantExpressionEvalUtils {
buildVariant(builder, element, elementType)
}
builder.finishWritingArray(start, offsets)
- case MapType(StringType, valueType, _) =>
+ case MapType(_: StringType, valueType, _) =>
val data = input.asInstanceOf[MapData]
val keys = data.keyArray()
val values = data.valueArray()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
index bd956fa5c00e1..2c8ca1e8bb2bb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.json.JsonInferSchema
import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, VARIANT_GET}
import org.apache.spark.sql.catalyst.trees.UnaryLike
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, QuotingUtils}
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
@@ -117,6 +117,73 @@ case class IsVariantNull(child: Expression) extends UnaryExpression
copy(child = newChild)
}
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(expr) - Convert a nested input (array/map/struct) into a variant where maps and structs are converted to variant objects which are unordered unlike SQL structs. Input maps can only have string keys.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_(named_struct('a', 1, 'b', 2));
+ {"a":1,"b":2}
+ > SELECT _FUNC_(array(1, 2, 3));
+ [1,2,3]
+ > SELECT _FUNC_(array(named_struct('a', 1)));
+ [{"a":1}]
+ > SELECT _FUNC_(array(map("a", 2)));
+ [{"a":2}]
+ """,
+ since = "4.0.0",
+ group = "variant_funcs")
+// scalastyle:on line.size.limit
+case class ToVariantObject(child: Expression)
+ extends UnaryExpression
+ with NullIntolerant
+ with QueryErrorsBase {
+
+ override val dataType: DataType = VariantType
+
+ // Only accept nested types at the root but any types can be nested inside.
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val checkResult: Boolean = child.dataType match {
+ case _: StructType | _: ArrayType | _: MapType =>
+ VariantGet.checkDataType(child.dataType, allowStructsAndMaps = true)
+ case _ => false
+ }
+ if (!checkResult) {
+ DataTypeMismatch(
+ errorSubClass = "CAST_WITHOUT_SUGGESTION",
+ messageParameters =
+ Map("srcType" -> toSQLType(child.dataType), "targetType" -> toSQLType(VariantType)))
+ } else {
+ TypeCheckResult.TypeCheckSuccess
+ }
+ }
+
+ override def prettyName: String = "to_variant_object"
+
+ override protected def withNewChildInternal(newChild: Expression): ToVariantObject =
+ copy(child = newChild)
+
+ protected override def nullSafeEval(input: Any): Any =
+ VariantExpressionEvalUtils.castToVariant(input, child.dataType)
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val childCode = child.genCode(ctx)
+ val cls = variant.VariantExpressionEvalUtils.getClass.getName.stripSuffix("$")
+ val fromArg = ctx.addReferenceObj("from", child.dataType)
+ val javaType = JavaCode.javaType(VariantType)
+ val code =
+ code"""
+ ${childCode.code}
+ boolean ${ev.isNull} = ${childCode.isNull};
+ $javaType ${ev.value} = ${CodeGenerator.defaultValue(VariantType)};
+ if (!${childCode.isNull}) {
+ ${ev.value} = $cls.castToVariant(${childCode.value}, $fromArg);
+ }
+ """
+ ev.copy(code = code)
+ }
+}
+
object VariantPathParser extends RegexParsers {
// A path segment in the `VariantGet` expression represents either an object key access or an
// array index access.
@@ -260,13 +327,16 @@ case object VariantGet {
* Returns whether a data type can be cast into/from variant. For scalar types, we allow a subset
* of them. For nested types, we reject map types with a non-string key type.
*/
- def checkDataType(dataType: DataType): Boolean = dataType match {
+ def checkDataType(dataType: DataType, allowStructsAndMaps: Boolean = true): Boolean =
+ dataType match {
case _: NumericType | BooleanType | _: StringType | BinaryType | _: DatetimeType |
VariantType | _: DayTimeIntervalType | _: YearMonthIntervalType =>
true
- case ArrayType(elementType, _) => checkDataType(elementType)
- case MapType(_: StringType, valueType, _) => checkDataType(valueType)
- case StructType(fields) => fields.forall(f => checkDataType(f.dataType))
+ case ArrayType(elementType, _) => checkDataType(elementType, allowStructsAndMaps)
+ case MapType(_: StringType, valueType, _) if allowStructsAndMaps =>
+ checkDataType(valueType, allowStructsAndMaps)
+ case StructType(fields) if allowStructsAndMaps =>
+ fields.forall(f => checkDataType(f.dataType, allowStructsAndMaps))
case _ => false
}
@@ -635,7 +705,7 @@ object VariantExplode {
> SELECT _FUNC_(parse_json('null'));
VOID
> SELECT _FUNC_(parse_json('[{"b":true,"a":0}]'));
- ARRAY<STRUCT<a: BIGINT, b: BOOLEAN>>
+ ARRAY<OBJECT<a: BIGINT, b: BOOLEAN>>
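A hedged usage sketch for the new `to_variant_object` expression (a local Spark 4.0 session assumed):

```scala
// Illustrative only, assuming a local Spark 4.0 SparkSession.
import org.apache.spark.sql.SparkSession
val spark = SparkSession.builder().master("local[1]").getOrCreate()

// Only nested root types are accepted (struct, array, or map with string keys);
// the resulting variant objects are unordered, unlike SQL structs.
spark.sql("SELECT to_variant_object(named_struct('a', 1, 'b', 2)) AS v").show()

// A scalar at the root fails the DATATYPE_MISMATCH.CAST_WITHOUT_SUGGESTION check:
// spark.sql("SELECT to_variant_object(1)")
```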
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
index 92709ff29a1ca..b64637f7d2472 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
@@ -63,7 +63,7 @@ object Connect {
"conservatively use 70% of it because the size is not accurate but estimated.")
.version("3.4.0")
.bytesConf(ByteUnit.BYTE)
- .createWithDefault(4 * 1024 * 1024)
+ .createWithDefault(ConnectCommon.CONNECT_GRPC_MAX_MESSAGE_SIZE)
val CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE =
buildStaticConf("spark.connect.grpc.maxInboundMessageSize")
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListener.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListener.scala
index a1881765a416c..72c77fd033d76 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListener.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListener.scala
@@ -144,7 +144,9 @@ private[connect] class ConnectProgressExecutionListener extends SparkListener wi
tracker.stages.get(taskEnd.stageId).foreach { stage =>
stage.update { i =>
i.completedTasks += 1
- i.inputBytesRead += taskEnd.taskMetrics.inputMetrics.bytesRead
+ i.inputBytesRead += Option(taskEnd.taskMetrics)
+ .map(_.inputMetrics.bytesRead)
+ .getOrElse(0L)
}
}
// This should never become negative, simply reset to zero if it does.
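The guarded read in isolation, as a sketch (`bytesReadOrZero` is a hypothetical helper name; `taskEnd.taskMetrics` can be null in some failure paths, which is what the change protects against):

```scala
import org.apache.spark.executor.TaskMetrics

// A missing TaskMetrics contributes zero bytes instead of throwing a
// NullPointerException when the progress counters are updated.
def bytesReadOrZero(taskMetrics: TaskMetrics): Long =
  Option(taskMetrics).map(_.inputMetrics.bytesRead).getOrElse(0L)

assert(bytesReadOrZero(null) == 0L)
```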
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala
index 3e360372d5600..051093fcad277 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteGrpcResponseSender.scala
@@ -142,7 +142,9 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message](
* client, but rather enqueued to in the response observer.
*/
private def enqueueProgressMessage(force: Boolean = false): Unit = {
- if (executeHolder.sessionHolder.session.conf.get(CONNECT_PROGRESS_REPORT_INTERVAL) > 0) {
+ val progressReportInterval = executeHolder.sessionHolder.session.sessionState.conf
+ .getConf(CONNECT_PROGRESS_REPORT_INTERVAL)
+ if (progressReportInterval > 0) {
SparkConnectService.executionListener.foreach { listener =>
// It is possible, that the tracker is no longer available and in this
// case we simply ignore it and do not send any progress message. This avoids
@@ -240,8 +242,8 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message](
// monitor, and will notify upon state change.
if (response.isEmpty) {
// Wake up more frequently to send the progress updates.
- val progressTimeout =
- executeHolder.sessionHolder.session.conf.get(CONNECT_PROGRESS_REPORT_INTERVAL)
+ val progressTimeout = executeHolder.sessionHolder.session.sessionState.conf
+ .getConf(CONNECT_PROGRESS_REPORT_INTERVAL)
// If the progress feature is disabled, wait for the deadline.
val timeout = if (progressTimeout > 0) {
progressTimeout
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 58e61badaf370..33c9edb1cd21a 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -40,14 +40,14 @@ import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult
import org.apache.spark.connect.proto.Parse.ParseFormat
import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.StreamingQueryInstance
import org.apache.spark.connect.proto.WriteStreamOperationStart.TriggerCase
-import org.apache.spark.internal.{Logging, MDC}
+import org.apache.spark.internal.{Logging, LogKeys, MDC}
import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, SESSION_ID}
import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, TaskResourceProfile, TaskResourceRequest}
-import org.apache.spark.sql.{Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, SparkSession}
+import org.apache.spark.sql.{Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, Row, SparkSession}
import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro}
import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier, QueryPlanningTracker}
-import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
-import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder, RowEncoder}
+import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar, UnresolvedTranspose}
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, AgnosticEncoder, ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
@@ -78,9 +78,8 @@ import org.apache.spark.sql.execution.stat.StatFunctions
import org.apache.spark.sql.execution.streaming.GroupStateImpl.groupStateTimeoutFromString
import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper
import org.apache.spark.sql.expressions.{Aggregator, ReduceAggregator, SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction}
-import org.apache.spark.sql.internal.{CatalogImpl, TypedAggUtils}
+import org.apache.spark.sql.internal.{CatalogImpl, MergeIntoWriterImpl, TypedAggUtils}
import org.apache.spark.sql.internal.ExpressionUtils.column
-import org.apache.spark.sql.protobuf.{CatalystDataToProtobuf, ProtobufDataToCatalyst}
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StreamingQuery, StreamingQueryListener, StreamingQueryProgress, Trigger}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -204,6 +203,7 @@ class SparkConnectPlanner(
transformCachedLocalRelation(rel.getCachedLocalRelation)
case proto.Relation.RelTypeCase.HINT => transformHint(rel.getHint)
case proto.Relation.RelTypeCase.UNPIVOT => transformUnpivot(rel.getUnpivot)
+ case proto.Relation.RelTypeCase.TRANSPOSE => transformTranspose(rel.getTranspose)
case proto.Relation.RelTypeCase.REPARTITION_BY_EXPRESSION =>
transformRepartitionByExpression(rel.getRepartitionByExpression)
case proto.Relation.RelTypeCase.MAP_PARTITIONS =>
@@ -1126,6 +1126,13 @@ class SparkConnectPlanner(
UnresolvedHint(rel.getName, params, transformRelation(rel.getInput))
}
+ private def transformTranspose(rel: proto.Transpose): LogicalPlan = {
+ val child = transformRelation(rel.getInput)
+ val indices = rel.getIndexColumnsList.asScala.map(transformExpression).toSeq
+
+ UnresolvedTranspose(indices = indices, child = child)
+ }
+
private def transformUnpivot(rel: proto.Unpivot): LogicalPlan = {
val ids = rel.getIdsList.asScala.toArray.map { expr =>
column(transformExpression(expr))
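The new TRANSPOSE case above plugs the Connect proto into the analyzer's UnresolvedTranspose node. A minimal client-side sketch of what ends up in that proto, assuming the Spark 4.0 `Dataset.transpose` API on the Connect Scala client (data and column names here are illustrative):

```scala
// Assumes a Spark Connect session `spark` is already available.
import spark.implicits._

val df = Seq(("a", 1, 2), ("b", 3, 4)).toDF("key", "x", "y")

// Sends a proto.Transpose relation whose index_columns list carries the `key`
// expression; the planner above rewrites it into
// UnresolvedTranspose(indices = Seq(<key>), child = <df plan>).
val transposed = df.transpose($"key")
transposed.show()
```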
@@ -1863,67 +1870,15 @@ class SparkConnectPlanner(
}
Some(CatalystDataToAvro(children.head, jsonFormatSchema))
- // Protobuf-specific functions
- case "from_protobuf" if Seq(2, 3, 4).contains(fun.getArgumentsCount) =>
- val children = fun.getArgumentsList.asScala.map(transformExpression)
- val (msgName, desc, options) = extractProtobufArgs(children.toSeq)
- Some(ProtobufDataToCatalyst(children(0), msgName, desc, options))
-
- case "to_protobuf" if Seq(2, 3, 4).contains(fun.getArgumentsCount) =>
- val children = fun.getArgumentsList.asScala.map(transformExpression)
- val (msgName, desc, options) = extractProtobufArgs(children.toSeq)
- Some(CatalystDataToProtobuf(children(0), msgName, desc, options))
-
case _ => None
}
}
- private def extractProtobufArgs(children: Seq[Expression]) = {
- val msgName = extractString(children(1), "MessageClassName")
- var desc = Option.empty[Array[Byte]]
- var options = Map.empty[String, String]
- if (children.length == 3) {
- children(2) match {
- case b: Literal => desc = Some(extractBinary(b, "binaryFileDescriptorSet"))
- case o => options = extractMapData(o, "options")
- }
- } else if (children.length == 4) {
- desc = Some(extractBinary(children(2), "binaryFileDescriptorSet"))
- options = extractMapData(children(3), "options")
- }
- (msgName, desc, options)
- }
-
- private def extractBoolean(expr: Expression, field: String): Boolean = expr match {
- case Literal(bool: Boolean, BooleanType) => bool
- case other => throw InvalidPlanInput(s"$field should be a literal boolean, but got $other")
- }
-
- private def extractDouble(expr: Expression, field: String): Double = expr match {
- case Literal(double: Double, DoubleType) => double
- case other => throw InvalidPlanInput(s"$field should be a literal double, but got $other")
- }
-
- private def extractInteger(expr: Expression, field: String): Int = expr match {
- case Literal(int: Int, IntegerType) => int
- case other => throw InvalidPlanInput(s"$field should be a literal integer, but got $other")
- }
-
- private def extractLong(expr: Expression, field: String): Long = expr match {
- case Literal(long: Long, LongType) => long
- case other => throw InvalidPlanInput(s"$field should be a literal long, but got $other")
- }
-
private def extractString(expr: Expression, field: String): String = expr match {
case Literal(s, StringType) if s != null => s.toString
case other => throw InvalidPlanInput(s"$field should be a literal string, but got $other")
}
- private def extractBinary(expr: Expression, field: String): Array[Byte] = expr match {
- case Literal(b: Array[Byte], BinaryType) if b != null => b
- case other => throw InvalidPlanInput(s"$field should be a literal binary, but got $other")
- }
-
@scala.annotation.tailrec
private def extractMapData(expr: Expression, field: String): Map[String, String] = expr match {
case map: CreateMap => ExprUtils.convertToMapData(map)
@@ -1932,23 +1887,6 @@ class SparkConnectPlanner(
case other => throw InvalidPlanInput(s"$field should be created by map, but got $other")
}
- // Extract the schema from a literal string representing a JSON-formatted schema
- private def extractDataTypeFromJSON(exp: proto.Expression): Option[DataType] = {
- exp.getExprTypeCase match {
- case proto.Expression.ExprTypeCase.LITERAL =>
- exp.getLiteral.getLiteralTypeCase match {
- case proto.Expression.Literal.LiteralTypeCase.STRING =>
- try {
- Some(DataType.fromJson(exp.getLiteral.getString))
- } catch {
- case _: Exception => None
- }
- case _ => None
- }
- case _ => None
- }
- }
-
private def transformAlias(alias: proto.Expression.Alias): NamedExpression = {
if (alias.getNameCount == 1) {
val metadata = if (alias.hasMetadata() && alias.getMetadata.nonEmpty) {
@@ -2380,16 +2318,17 @@ class SparkConnectPlanner(
if (fun.getArgumentsCount != 1) {
throw InvalidPlanInput("reduce requires single child expression")
}
- val udf = fun.getArgumentsList.asScala.map(transformExpression) match {
- case collection.Seq(f: ScalaUDF) =>
- f
+ val udf = fun.getArgumentsList.asScala match {
+ case collection.Seq(e)
+ if e.hasCommonInlineUserDefinedFunction &&
+ e.getCommonInlineUserDefinedFunction.hasScalarScalaUdf =>
+ unpackUdf(e.getCommonInlineUserDefinedFunction)
case other =>
throw InvalidPlanInput(s"reduce should carry a scalar scala udf, but got $other")
}
- assert(udf.outputEncoder.isDefined)
- val tEncoder = udf.outputEncoder.get // (T, T) => T
- val reduce = ReduceAggregator(udf.function)(tEncoder).toColumn.expr
- TypedAggUtils.withInputType(reduce, tEncoder, dataAttributes)
+ val encoder = udf.outputEncoder
+ val reduce = ReduceAggregator(udf.function)(encoder).toColumn.expr
+ TypedAggUtils.withInputType(reduce, encoderFor(encoder), dataAttributes)
}
private def transformExpressionWithTypedReduceExpression(
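For context on the reworked reduce handling above: instead of requiring the argument to already be a transformed ScalaUDF with a defined output encoder, the planner now unpacks the inline Scala UDF from the proto and types the ReduceAggregator with that UDF's output encoder. The user-facing call that produces this shape is `Dataset.reduce`; a small sketch, assuming a Connect SparkSession `spark`:

```scala
import spark.implicits._

// Dataset.reduce ships a single (T, T) => T Scala UDF to the server; the planner
// wraps it in ReduceAggregator[T] keyed by the UDF's output encoder.
val ds = spark.range(1, 5).as[Long]   // 1, 2, 3, 4
val total = ds.reduce(_ + _)          // arrives as the "reduce" function with one UDF argument
assert(total == 10L)
```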
@@ -3114,10 +3053,13 @@ class SparkConnectPlanner(
sessionHolder.streamingServersideListenerHolder.streamingQueryStartedEventCache.remove(
query.runId.toString))
queryStartedEvent.foreach {
- logDebug(
- s"[SessionId: $sessionId][UserId: $userId][operationId: " +
- s"${executeHolder.operationId}][query id: ${query.id}][query runId: ${query.runId}] " +
- s"Adding QueryStartedEvent to response")
+ logInfo(
+ log"[SessionId: ${MDC(LogKeys.SESSION_ID, sessionId)}]" +
+ log"[UserId: ${MDC(LogKeys.USER_ID, userId)}] " +
+ log"[operationId: ${MDC(LogKeys.OPERATION_ID, executeHolder.operationId)}] " +
+ log"[query id: ${MDC(LogKeys.QUERY_ID, query.id)}]" +
+ log"[query runId: ${MDC(LogKeys.QUERY_RUN_ID, query.runId)}] " +
+ log"Adding QueryStartedEvent to response")
e => resultBuilder.setQueryStartedEventJson(e.json)
}
@@ -3496,16 +3438,14 @@ class SparkConnectPlanner(
val notMatchedBySourceActions = transformActions(cmd.getNotMatchedBySourceActionsList)
val sourceDs = Dataset.ofRows(session, transformRelation(cmd.getSourceTablePlan))
- var mergeInto = sourceDs
+ val mergeInto = sourceDs
.mergeInto(cmd.getTargetTableName, column(transformExpression(cmd.getMergeCondition)))
- .withNewMatchedActions(matchedActions: _*)
- .withNewNotMatchedActions(notMatchedActions: _*)
- .withNewNotMatchedBySourceActions(notMatchedBySourceActions: _*)
-
- mergeInto = if (cmd.getWithSchemaEvolution) {
+ .asInstanceOf[MergeIntoWriterImpl[Row]]
+ mergeInto.matchedActions ++= matchedActions
+ mergeInto.notMatchedActions ++= notMatchedActions
+ mergeInto.notMatchedBySourceActions ++= notMatchedBySourceActions
+ if (cmd.getWithSchemaEvolution) {
mergeInto.withSchemaEvolution()
- } else {
- mergeInto
}
mergeInto.merge()
executeHolder.eventsManager.postFinished()
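The MERGE INTO command handler now casts to MergeIntoWriterImpl and appends the already-transformed actions directly to its internal buffers instead of re-deriving them through the public when* builders. The client-side API that produces those actions looks roughly like this (a sketch, assuming the Spark 4.0 MergeIntoWriter API; table and column names are made up, and the target table must be in a format that supports MERGE INTO):

```scala
import org.apache.spark.sql.functions.expr

// Builds matched / not-matched actions on the client; the server-side planner
// receives them as proto actions and injects them into MergeIntoWriterImpl.
spark.table("source")
  .mergeInto("target", expr("source.id = target.id"))
  .whenMatched()
  .updateAll()
  .whenNotMatched()
  .insertAll()
  .withSchemaEvolution()
  .merge()
```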
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
index ec7ebbe92d72e..dc349c3e33251 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
@@ -17,12 +17,10 @@
package org.apache.spark.sql.connect.service
-import java.util.UUID
-
import scala.collection.mutable
import scala.jdk.CollectionConverters._
-import org.apache.spark.{SparkEnv, SparkSQLException}
+import org.apache.spark.SparkEnv
import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
import org.apache.spark.sql.Observation
@@ -35,30 +33,19 @@ import org.apache.spark.util.SystemClock
* Object used to hold the Spark Connect execution state.
*/
private[connect] class ExecuteHolder(
+ val executeKey: ExecuteKey,
val request: proto.ExecutePlanRequest,
val sessionHolder: SessionHolder)
extends Logging {
val session = sessionHolder.session
- val operationId = if (request.hasOperationId) {
- try {
- UUID.fromString(request.getOperationId).toString
- } catch {
- case _: IllegalArgumentException =>
- throw new SparkSQLException(
- errorClass = "INVALID_HANDLE.FORMAT",
- messageParameters = Map("handle" -> request.getOperationId))
- }
- } else {
- UUID.randomUUID().toString
- }
-
/**
* Tag that is set for this execution on SparkContext, via SparkContext.addJobTag. Used
* (internally) for cancellation of the Spark Jobs ran by this execution.
*/
- val jobTag = ExecuteJobTag(sessionHolder.userId, sessionHolder.sessionId, operationId)
+ val jobTag =
+ ExecuteJobTag(sessionHolder.userId, sessionHolder.sessionId, executeKey.operationId)
/**
* Tags set by Spark Connect client users via SparkSession.addTag. Used to identify and group
@@ -278,7 +265,7 @@ private[connect] class ExecuteHolder(
request = request,
userId = sessionHolder.userId,
sessionId = sessionHolder.sessionId,
- operationId = operationId,
+ operationId = executeKey.operationId,
jobTag = jobTag,
sparkSessionTags = sparkSessionTags,
reattachable = reattachable,
@@ -289,7 +276,10 @@ private[connect] class ExecuteHolder(
}
/** Get key used by SparkConnectExecutionManager global tracker. */
- def key: ExecuteKey = ExecuteKey(sessionHolder.userId, sessionHolder.sessionId, operationId)
+ def key: ExecuteKey = executeKey
+
+ /** Get the operation ID. */
+ def operationId: String = key.operationId
}
/** Used to identify ExecuteHolder jobTag among SparkContext.SPARK_JOB_TAGS. */
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
index 0cb820b39e875..e56d66da3050d 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
@@ -444,8 +444,8 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
*/
private[connect] def usePlanCache(rel: proto.Relation, cachePlan: Boolean)(
transform: proto.Relation => LogicalPlan): LogicalPlan = {
- val planCacheEnabled =
- Option(session).forall(_.conf.get(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, true))
+ val planCacheEnabled = Option(session)
+ .forall(_.sessionState.conf.getConf(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, true))
// We only cache plans that have a plan ID.
val hasPlanId = rel.hasCommon && rel.getCommon.hasPlanId
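The plan-cache flag is now read through the session's SQLConf handle with an explicit default rather than through the RuntimeConfig wrapper. One practical consequence, visible in the test change further below, is that RuntimeConfig is string-keyed, so a ConfigEntry must be dereferenced with `.key` when set that way. A small sketch, assuming a SessionHolder `sessionHolder` in scope:

```scala
import org.apache.spark.sql.connect.config.Connect

// Setting through RuntimeConfig requires the string key of the ConfigEntry ...
sessionHolder.session.conf.set(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED.key, false)

// ... while the server reads it back through the typed SQLConf handle,
// supplying a default for sessions that never set it.
val planCacheEnabled = sessionHolder.session.sessionState.conf
  .getConf(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, true)   // false here
```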
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala
index 1ab5f26f90b13..73a20e448be87 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala
@@ -31,6 +31,15 @@ class SparkConnectExecutePlanHandler(responseObserver: StreamObserver[proto.Exec
try {
executeHolder.eventsManager.postStarted()
executeHolder.start()
+ } catch {
+      // Errors raised before the execution holder has finished spawning a thread are considered
+      // plan execution failures, and the client should not try to reattach afterwards.
+ case t: Throwable =>
+ SparkConnectService.executionManager.removeExecuteHolder(executeHolder.key)
+ throw t
+ }
+
+ try {
val responseSender =
new ExecuteGrpcResponseSender[proto.ExecutePlanResponse](executeHolder, responseObserver)
executeHolder.runGrpcResponseSender(responseSender)
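The handler now separates the startup phase from the response-streaming phase: if postStarted()/start() throws, the execute holder is removed from the global manager before the error propagates, so a subsequent ReattachExecute sees OPERATION_NOT_FOUND instead of attaching to a holder that never started (the new ReattachableExecuteSuite test below exercises exactly this). The general shape of the pattern, with hypothetical names:

```scala
import java.util.concurrent.ConcurrentHashMap

// Roll back a registration when startup fails, so retries observe "not found"
// rather than a registered-but-dead entry. Names are illustrative.
def startOrUnregister[K, H](registry: ConcurrentHashMap[K, H], key: K, holder: H)(
    start: H => Unit): Unit = {
  try {
    start(holder)
  } catch {
    case t: Throwable =>
      registry.remove(key)   // undo the registration before propagating
      throw t
  }
}
```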
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
index 6681a5f509c6e..d66964b8d34bd 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala
@@ -17,8 +17,9 @@
package org.apache.spark.sql.connect.service
-import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit}
-import javax.annotation.concurrent.GuardedBy
+import java.util.UUID
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Executors, ScheduledExecutorService, TimeUnit}
+import java.util.concurrent.atomic.{AtomicLong, AtomicReference}
import scala.collection.mutable
import scala.concurrent.duration.FiniteDuration
@@ -36,6 +37,24 @@ import org.apache.spark.util.ThreadUtils
// Unique key identifying execution by combination of user, session and operation id
case class ExecuteKey(userId: String, sessionId: String, operationId: String)
+object ExecuteKey {
+ def apply(request: proto.ExecutePlanRequest, sessionHolder: SessionHolder): ExecuteKey = {
+ val operationId = if (request.hasOperationId) {
+ try {
+ UUID.fromString(request.getOperationId).toString
+ } catch {
+ case _: IllegalArgumentException =>
+ throw new SparkSQLException(
+ errorClass = "INVALID_HANDLE.FORMAT",
+ messageParameters = Map("handle" -> request.getOperationId))
+ }
+ } else {
+ UUID.randomUUID().toString
+ }
+ ExecuteKey(sessionHolder.userId, sessionHolder.sessionId, operationId)
+ }
+}
+
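ExecuteKey.apply now owns the handle validation that used to live in ExecuteHolder: a client-supplied operation id must parse as a UUID (and is normalized through toString), otherwise the request is rejected with INVALID_HANDLE.FORMAT, and a missing id gets a fresh random UUID. The core of that check as a standalone sketch, using a plain IllegalArgumentException in place of SparkSQLException:

```scala
import java.util.UUID

// Validate and normalize a client-supplied operation id, or mint a new one.
def normalizeOperationId(requested: Option[String]): String = requested match {
  case Some(id) =>
    try UUID.fromString(id).toString   // also canonicalizes the formatting
    catch {
      case _: IllegalArgumentException =>
        throw new IllegalArgumentException(s"INVALID_HANDLE.FORMAT: $id")
    }
  case None => UUID.randomUUID().toString
}
```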
/**
* Global tracker of all ExecuteHolder executions.
*
@@ -44,11 +63,9 @@ case class ExecuteKey(userId: String, sessionId: String, operationId: String)
*/
private[connect] class SparkConnectExecutionManager() extends Logging {
- /** Hash table containing all current executions. Guarded by executionsLock. */
- @GuardedBy("executionsLock")
- private val executions: mutable.HashMap[ExecuteKey, ExecuteHolder] =
- new mutable.HashMap[ExecuteKey, ExecuteHolder]()
- private val executionsLock = new Object
+ /** Concurrent hash table containing all the current executions. */
+ private val executions: ConcurrentMap[ExecuteKey, ExecuteHolder] =
+ new ConcurrentHashMap[ExecuteKey, ExecuteHolder]()
/** Graveyard of tombstones of executions that were abandoned and removed. */
private val abandonedTombstones = CacheBuilder
@@ -56,12 +73,12 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
.maximumSize(SparkEnv.get.conf.get(CONNECT_EXECUTE_MANAGER_ABANDONED_TOMBSTONES_SIZE))
.build[ExecuteKey, ExecuteInfo]()
- /** None if there are no executions. Otherwise, the time when the last execution was removed. */
- @GuardedBy("executionsLock")
- private var lastExecutionTimeMs: Option[Long] = Some(System.currentTimeMillis())
+ /** The time when the last execution was removed. */
+ private var lastExecutionTimeMs: AtomicLong = new AtomicLong(System.currentTimeMillis())
/** Executor for the periodic maintenance */
- private var scheduledExecutor: Option[ScheduledExecutorService] = None
+ private var scheduledExecutor: AtomicReference[ScheduledExecutorService] =
+ new AtomicReference[ScheduledExecutorService]()
/**
* Create a new ExecuteHolder and register it with this global manager and with its session.
@@ -76,27 +93,30 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
request.getUserContext.getUserId,
request.getSessionId,
previousSessionId)
- val executeHolder = new ExecuteHolder(request, sessionHolder)
- executionsLock.synchronized {
- // Check if the operation already exists, both in active executions, and in the graveyard
- // of tombstones of executions that have been abandoned.
- // The latter is to prevent double execution when a client retries execution, thinking it
- // never reached the server, but in fact it did, and already got removed as abandoned.
- if (executions.get(executeHolder.key).isDefined) {
- throw new SparkSQLException(
- errorClass = "INVALID_HANDLE.OPERATION_ALREADY_EXISTS",
- messageParameters = Map("handle" -> executeHolder.operationId))
- }
- if (getAbandonedTombstone(executeHolder.key).isDefined) {
- throw new SparkSQLException(
- errorClass = "INVALID_HANDLE.OPERATION_ABANDONED",
- messageParameters = Map("handle" -> executeHolder.operationId))
- }
- sessionHolder.addExecuteHolder(executeHolder)
- executions.put(executeHolder.key, executeHolder)
- lastExecutionTimeMs = None
- logInfo(log"ExecuteHolder ${MDC(LogKeys.EXECUTE_KEY, executeHolder.key)} is created.")
- }
+ val executeKey = ExecuteKey(request, sessionHolder)
+ val executeHolder = executions.compute(
+ executeKey,
+ (executeKey, oldExecuteHolder) => {
+ // Check if the operation already exists, either in the active execution map, or in the
+ // graveyard of tombstones of executions that have been abandoned. The latter is to prevent
+ // double executions when the client retries, thinking it never reached the server, but in
+ // fact it did, and already got removed as abandoned.
+ if (oldExecuteHolder != null) {
+ throw new SparkSQLException(
+ errorClass = "INVALID_HANDLE.OPERATION_ALREADY_EXISTS",
+ messageParameters = Map("handle" -> executeKey.operationId))
+ }
+ if (getAbandonedTombstone(executeKey).isDefined) {
+ throw new SparkSQLException(
+ errorClass = "INVALID_HANDLE.OPERATION_ABANDONED",
+ messageParameters = Map("handle" -> executeKey.operationId))
+ }
+ new ExecuteHolder(executeKey, request, sessionHolder)
+ })
+
+ sessionHolder.addExecuteHolder(executeHolder)
+
+ logInfo(log"ExecuteHolder ${MDC(LogKeys.EXECUTE_KEY, executeHolder.key)} is created.")
schedulePeriodicChecks() // Starts the maintenance thread if it hasn't started.
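With the executions map now a ConcurrentHashMap, the duplicate-operation check and the insertion happen inside a single compute call, so they cannot interleave with a concurrent request carrying the same operation id (compute runs atomically per key). The idiom in isolation:

```scala
import java.util.concurrent.ConcurrentHashMap

// Atomic check-and-insert: the remapping function runs atomically for a key,
// so no other thread can observe a half-registered entry.
val registry = new ConcurrentHashMap[String, String]()

def register(key: String, value: String): String =
  registry.compute(key, (k, existing) => {
    if (existing != null) {
      throw new IllegalStateException(s"operation $k already exists")
    }
    value
  })

register("op-1", "holder-1")       // inserts
// register("op-1", "holder-2")    // would throw: "operation op-1 already exists"
```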
@@ -108,43 +128,46 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
* execution if still running, free all resources.
*/
private[connect] def removeExecuteHolder(key: ExecuteKey, abandoned: Boolean = false): Unit = {
- var executeHolder: Option[ExecuteHolder] = None
- executionsLock.synchronized {
- executeHolder = executions.remove(key)
- executeHolder.foreach { e =>
- // Put into abandonedTombstones under lock, so that if it's accessed it will end up
- // with INVALID_HANDLE.OPERATION_ABANDONED error.
- if (abandoned) {
- abandonedTombstones.put(key, e.getExecuteInfo)
- }
- e.sessionHolder.removeExecuteHolder(e.operationId)
- }
- if (executions.isEmpty) {
- lastExecutionTimeMs = Some(System.currentTimeMillis())
- }
- logInfo(log"ExecuteHolder ${MDC(LogKeys.EXECUTE_KEY, key)} is removed.")
+ val executeHolder = executions.get(key)
+
+ if (executeHolder == null) {
+ return
}
- // close the execution outside the lock
- executeHolder.foreach { e =>
- e.close()
- if (abandoned) {
- // Update in abandonedTombstones: above it wasn't yet updated with closedTime etc.
- abandonedTombstones.put(key, e.getExecuteInfo)
- }
+
+ // Put into abandonedTombstones before removing it from executions, so that the client ends up
+ // getting an INVALID_HANDLE.OPERATION_ABANDONED error on a retry.
+ if (abandoned) {
+ abandonedTombstones.put(key, executeHolder.getExecuteInfo)
+ }
+
+ // Remove the execution from the map *after* putting it in abandonedTombstones.
+ executions.remove(key)
+ executeHolder.sessionHolder.removeExecuteHolder(executeHolder.operationId)
+
+ updateLastExecutionTime()
+
+ logInfo(log"ExecuteHolder ${MDC(LogKeys.EXECUTE_KEY, key)} is removed.")
+
+ executeHolder.close()
+ if (abandoned) {
+ // Update in abandonedTombstones: above it wasn't yet updated with closedTime etc.
+ abandonedTombstones.put(key, executeHolder.getExecuteInfo)
}
}
private[connect] def getExecuteHolder(key: ExecuteKey): Option[ExecuteHolder] = {
- executionsLock.synchronized {
- executions.get(key)
- }
+ Option(executions.get(key))
}
private[connect] def removeAllExecutionsForSession(key: SessionKey): Unit = {
- val sessionExecutionHolders = executionsLock.synchronized {
- executions.filter(_._2.sessionHolder.key == key)
- }
- sessionExecutionHolders.foreach { case (_, executeHolder) =>
+    val sessionExecutionHolders = mutable.ArrayBuffer[ExecuteHolder]()
+ executions.forEach((_, executeHolder) => {
+ if (executeHolder.sessionHolder.key == key) {
+ sessionExecutionHolders += executeHolder
+ }
+ })
+
+ sessionExecutionHolders.foreach { executeHolder =>
val info = executeHolder.getExecuteInfo
logInfo(
log"Execution ${MDC(LogKeys.EXECUTE_INFO, info)} removed in removeSessionExecutions.")
@@ -161,11 +184,11 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
* If there are no executions, return Left with System.currentTimeMillis of last active
* execution. Otherwise return Right with list of ExecuteInfo of all executions.
*/
- def listActiveExecutions: Either[Long, Seq[ExecuteInfo]] = executionsLock.synchronized {
+ def listActiveExecutions: Either[Long, Seq[ExecuteInfo]] = {
if (executions.isEmpty) {
- Left(lastExecutionTimeMs.get)
+ Left(lastExecutionTimeMs.getAcquire())
} else {
- Right(executions.values.map(_.getExecuteInfo).toBuffer.toSeq)
+ Right(executions.values().asScala.map(_.getExecuteInfo).toBuffer.toSeq)
}
}
@@ -177,17 +200,24 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
abandonedTombstones.asMap.asScala.values.toSeq
}
- private[connect] def shutdown(): Unit = executionsLock.synchronized {
- scheduledExecutor.foreach { executor =>
+ private[connect] def shutdown(): Unit = {
+ val executor = scheduledExecutor.getAndSet(null)
+ if (executor != null) {
ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES))
}
- scheduledExecutor = None
+
// note: this does not cleanly shut down the executions, but the server is shutting down.
executions.clear()
abandonedTombstones.invalidateAll()
- if (lastExecutionTimeMs.isEmpty) {
- lastExecutionTimeMs = Some(System.currentTimeMillis())
- }
+
+ updateLastExecutionTime()
+ }
+
+ /**
+ * Updates the last execution time after the last execution has been removed.
+ */
+ private def updateLastExecutionTime(): Unit = {
+ lastExecutionTimeMs.getAndUpdate(prev => prev.max(System.currentTimeMillis()))
}
/**
@@ -195,16 +225,16 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
* for executions that have not been closed, but are left with no RPC attached to them, and
* removes them after a timeout.
*/
- private def schedulePeriodicChecks(): Unit = executionsLock.synchronized {
- scheduledExecutor match {
- case Some(_) => // Already running.
- case None =>
+ private def schedulePeriodicChecks(): Unit = {
+ var executor = scheduledExecutor.getAcquire()
+ if (executor == null) {
+ executor = Executors.newSingleThreadScheduledExecutor()
+ if (scheduledExecutor.compareAndExchangeRelease(null, executor) == null) {
val interval = SparkEnv.get.conf.get(CONNECT_EXECUTE_MANAGER_MAINTENANCE_INTERVAL)
logInfo(
log"Starting thread for cleanup of abandoned executions every " +
log"${MDC(LogKeys.INTERVAL, interval)} ms")
- scheduledExecutor = Some(Executors.newSingleThreadScheduledExecutor())
- scheduledExecutor.get.scheduleAtFixedRate(
+ executor.scheduleAtFixedRate(
() => {
try {
val timeout = SparkEnv.get.conf.get(CONNECT_EXECUTE_MANAGER_DETACHED_TIMEOUT)
@@ -216,6 +246,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
interval,
interval,
TimeUnit.MILLISECONDS)
+ }
}
}
@@ -225,19 +256,18 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
// Find any detached executions that expired and should be removed.
val toRemove = new mutable.ArrayBuffer[ExecuteHolder]()
- executionsLock.synchronized {
- val nowMs = System.currentTimeMillis()
-
- executions.values.foreach { executeHolder =>
- executeHolder.lastAttachedRpcTimeMs match {
- case Some(detached) =>
- if (detached + timeout <= nowMs) {
- toRemove += executeHolder
- }
- case _ => // execution is active
- }
+ val nowMs = System.currentTimeMillis()
+
+ executions.forEach((_, executeHolder) => {
+ executeHolder.lastAttachedRpcTimeMs match {
+ case Some(detached) =>
+ if (detached + timeout <= nowMs) {
+ toRemove += executeHolder
+ }
+ case _ => // execution is active
}
- }
+ })
+
// .. and remove them.
toRemove.foreach { executeHolder =>
val info = executeHolder.getExecuteInfo
@@ -250,16 +280,16 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
}
// For testing.
- private[connect] def setAllRPCsDeadline(deadlineMs: Long) = executionsLock.synchronized {
- executions.values.foreach(_.setGrpcResponseSendersDeadline(deadlineMs))
+ private[connect] def setAllRPCsDeadline(deadlineMs: Long) = {
+ executions.values().asScala.foreach(_.setGrpcResponseSendersDeadline(deadlineMs))
}
// For testing.
- private[connect] def interruptAllRPCs() = executionsLock.synchronized {
- executions.values.foreach(_.interruptGrpcResponseSenders())
+ private[connect] def interruptAllRPCs() = {
+ executions.values().asScala.foreach(_.interruptGrpcResponseSenders())
}
- private[connect] def listExecuteHolders: Seq[ExecuteHolder] = executionsLock.synchronized {
- executions.values.toSeq
+ private[connect] def listExecuteHolders: Seq[ExecuteHolder] = {
+ executions.values().asScala.toSeq
}
}
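The maintenance thread is now started without a lock: the first caller to win the compareAndExchange installs its executor and schedules the periodic task, while any racing caller sees a non-null witness and backs off. The same pattern is reused in SparkConnectSessionManager and SparkConnectStreamingQueryCache below. A condensed sketch (this version also shuts down the loser's spare executor, which the production code simply drops):

```scala
import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit}
import java.util.concurrent.atomic.AtomicReference

val scheduled = new AtomicReference[ScheduledExecutorService]()

// Start the periodic task at most once, without synchronized blocks.
def ensureScheduled(intervalMs: Long)(task: () => Unit): Unit = {
  var executor = scheduled.getAcquire()
  if (executor == null) {
    executor = Executors.newSingleThreadScheduledExecutor()
    if (scheduled.compareAndExchangeRelease(null, executor) == null) {
      // We won the race: this executor is the single maintenance thread.
      executor.scheduleAtFixedRate(() => task(), intervalMs, intervalMs, TimeUnit.MILLISECONDS)
    } else {
      // Another thread installed its executor first; discard ours.
      executor.shutdown()
    }
  }
}
```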
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala
index 5b2205757648f..7a0c067ab430b 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala
@@ -160,9 +160,11 @@ private[sql] class SparkConnectListenerBusListener(
}
override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = {
- logDebug(
- s"[SessionId: ${sessionHolder.sessionId}][UserId: ${sessionHolder.userId}] " +
- s"Sending QueryTerminatedEvent to client, id: ${event.id} runId: ${event.runId}.")
+ logInfo(
+ log"[SessionId: ${MDC(LogKeys.SESSION_ID, sessionHolder.sessionId)}]" +
+ log"[UserId: ${MDC(LogKeys.USER_ID, sessionHolder.userId)}] " +
+ log"Sending QueryTerminatedEvent to client, id: ${MDC(LogKeys.QUERY_ID, event.id)} " +
+ log"runId: ${MDC(LogKeys.QUERY_RUN_ID, event.runId)}.")
send(event.json, StreamingQueryEventType.QUERY_TERMINATED_EVENT)
}
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
index edaaa640bf12e..4ca3a80bfb985 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
@@ -18,8 +18,8 @@
package org.apache.spark.sql.connect.service
import java.util.UUID
-import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit}
-import javax.annotation.concurrent.GuardedBy
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Executors, ScheduledExecutorService, TimeUnit}
+import java.util.concurrent.atomic.AtomicReference
import scala.collection.mutable
import scala.concurrent.duration.FiniteDuration
@@ -40,10 +40,8 @@ import org.apache.spark.util.ThreadUtils
*/
class SparkConnectSessionManager extends Logging {
- private val sessionsLock = new Object
-
- @GuardedBy("sessionsLock")
- private val sessionStore = mutable.HashMap[SessionKey, SessionHolder]()
+ private val sessionStore: ConcurrentMap[SessionKey, SessionHolder] =
+ new ConcurrentHashMap[SessionKey, SessionHolder]()
private val closedSessionsCache =
CacheBuilder
@@ -52,7 +50,8 @@ class SparkConnectSessionManager extends Logging {
.build[SessionKey, SessionHolderInfo]()
/** Executor for the periodic maintenance */
- private var scheduledExecutor: Option[ScheduledExecutorService] = None
+ private var scheduledExecutor: AtomicReference[ScheduledExecutorService] =
+ new AtomicReference[ScheduledExecutorService]()
private def validateSessionId(
key: SessionKey,
@@ -74,8 +73,6 @@ class SparkConnectSessionManager extends Logging {
val holder = getSession(
key,
Some(() => {
- // Executed under sessionsState lock in getSession, to guard against concurrent removal
- // and insertion into closedSessionsCache.
validateSessionCreate(key)
val holder = SessionHolder(key.userId, key.sessionId, newIsolatedSession())
holder.initializeSession()
@@ -121,43 +118,39 @@ class SparkConnectSessionManager extends Logging {
private def getSession(key: SessionKey, default: Option[() => SessionHolder]): SessionHolder = {
schedulePeriodicChecks() // Starts the maintenance thread if it hasn't started yet.
- sessionsLock.synchronized {
- // try to get existing session from store
- val sessionOpt = sessionStore.get(key)
- // create using default if missing
- val session = sessionOpt match {
- case Some(s) => s
- case None =>
- default match {
- case Some(callable) =>
- val session = callable()
- sessionStore.put(key, session)
- session
- case None =>
- null
- }
- }
- // record access time before returning
- session match {
- case null =>
- null
- case s: SessionHolder =>
- s.updateAccessTime()
- s
- }
+ // Get the existing session from the store or create a new one.
+ val session = default match {
+ case Some(callable) =>
+ sessionStore.computeIfAbsent(key, _ => callable())
+ case None =>
+ sessionStore.get(key)
+ }
+
+ // Record the access time before returning the session holder.
+ if (session != null) {
+ session.updateAccessTime()
}
+
+ session
}
// Removes session from sessionStore and returns it.
private def removeSessionHolder(key: SessionKey): Option[SessionHolder] = {
var sessionHolder: Option[SessionHolder] = None
- sessionsLock.synchronized {
- sessionHolder = sessionStore.remove(key)
- sessionHolder.foreach { s =>
- // Put into closedSessionsCache, so that it cannot get accidentally recreated
- // by getOrCreateIsolatedSession.
- closedSessionsCache.put(s.key, s.getSessionHolderInfo)
- }
+
+ // The session holder should remain in the session store until it is added to the closed session
+ // cache, because of a subtle data race: a new session with the same key can be created if the
+ // closed session cache does not contain the key right after the key has been removed from the
+ // session store.
+ sessionHolder = Option(sessionStore.get(key))
+
+ sessionHolder.foreach { s =>
+ // Put into closedSessionsCache to prevent the same session from being recreated by
+ // getOrCreateIsolatedSession.
+ closedSessionsCache.put(s.key, s.getSessionHolderInfo)
+
+ // Then, remove the session holder from the session store.
+ sessionStore.remove(key)
}
sessionHolder
}
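getSession now leans on ConcurrentHashMap.computeIfAbsent for the get-or-create path, which guarantees the session factory runs at most once per key, and removeSessionHolder records the tombstone in closedSessionsCache before taking the key out of the store so the same session cannot be recreated in between. The get-or-create idiom on its own:

```scala
import java.util.concurrent.ConcurrentHashMap

// computeIfAbsent: the factory runs only if the key is missing, atomically.
val sessions = new ConcurrentHashMap[String, String]()

def getOrCreate(key: String)(create: => String): String =
  sessions.computeIfAbsent(key, _ => create)

getOrCreate("user-1/session-1")("holder-A")   // creates and returns "holder-A"
getOrCreate("user-1/session-1")("holder-B")   // returns "holder-A"; factory not invoked
```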
@@ -171,26 +164,26 @@ class SparkConnectSessionManager extends Logging {
def closeSession(key: SessionKey): Unit = {
val sessionHolder = removeSessionHolder(key)
- // Rest of the cleanup outside sessionLock - the session cannot be accessed anymore by
- // getOrCreateIsolatedSession.
+ // Rest of the cleanup: the session cannot be accessed anymore by getOrCreateIsolatedSession.
sessionHolder.foreach(shutdownSessionHolder(_))
}
- private[connect] def shutdown(): Unit = sessionsLock.synchronized {
- scheduledExecutor.foreach { executor =>
+ private[connect] def shutdown(): Unit = {
+ val executor = scheduledExecutor.getAndSet(null)
+ if (executor != null) {
ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES))
}
- scheduledExecutor = None
+
// note: this does not cleanly shut down the sessions, but the server is shutting down.
sessionStore.clear()
closedSessionsCache.invalidateAll()
}
- def listActiveSessions: Seq[SessionHolderInfo] = sessionsLock.synchronized {
- sessionStore.values.map(_.getSessionHolderInfo).toSeq
+ def listActiveSessions: Seq[SessionHolderInfo] = {
+ sessionStore.values().asScala.map(_.getSessionHolderInfo).toSeq
}
- def listClosedSessions: Seq[SessionHolderInfo] = sessionsLock.synchronized {
+ def listClosedSessions: Seq[SessionHolderInfo] = {
closedSessionsCache.asMap.asScala.values.toSeq
}
@@ -199,16 +192,16 @@ class SparkConnectSessionManager extends Logging {
*
* The checks are looking to remove sessions that expired.
*/
- private def schedulePeriodicChecks(): Unit = sessionsLock.synchronized {
- scheduledExecutor match {
- case Some(_) => // Already running.
- case None =>
+ private def schedulePeriodicChecks(): Unit = {
+ var executor = scheduledExecutor.getAcquire()
+ if (executor == null) {
+ executor = Executors.newSingleThreadScheduledExecutor()
+ if (scheduledExecutor.compareAndExchangeRelease(null, executor) == null) {
val interval = SparkEnv.get.conf.get(CONNECT_SESSION_MANAGER_MAINTENANCE_INTERVAL)
logInfo(
log"Starting thread for cleanup of expired sessions every " +
log"${MDC(INTERVAL, interval)} ms")
- scheduledExecutor = Some(Executors.newSingleThreadScheduledExecutor())
- scheduledExecutor.get.scheduleAtFixedRate(
+ executor.scheduleAtFixedRate(
() => {
try {
val defaultInactiveTimeoutMs =
@@ -221,6 +214,7 @@ class SparkConnectSessionManager extends Logging {
interval,
interval,
TimeUnit.MILLISECONDS)
+ }
}
}
@@ -246,34 +240,27 @@ class SparkConnectSessionManager extends Logging {
timeoutMs != -1 && info.lastAccessTimeMs + timeoutMs <= nowMs
}
- sessionsLock.synchronized {
- val nowMs = System.currentTimeMillis()
- sessionStore.values.foreach { sessionHolder =>
- if (shouldExpire(sessionHolder.getSessionHolderInfo, nowMs)) {
- toRemove += sessionHolder
- }
+ val nowMs = System.currentTimeMillis()
+ sessionStore.forEach((_, sessionHolder) => {
+ if (shouldExpire(sessionHolder.getSessionHolderInfo, nowMs)) {
+ toRemove += sessionHolder
}
- }
+ })
+
// .. and remove them.
toRemove.foreach { sessionHolder =>
- // This doesn't use closeSession to be able to do the extra last chance check under lock.
- val removedSession = sessionsLock.synchronized {
- // Last chance - check expiration time and remove under lock if expired.
- val info = sessionHolder.getSessionHolderInfo
- if (shouldExpire(info, System.currentTimeMillis())) {
- logInfo(
- log"Found session ${MDC(SESSION_HOLD_INFO, info)} that expired " +
- log"and will be closed.")
- removeSessionHolder(info.key)
- } else {
- None
+ val info = sessionHolder.getSessionHolderInfo
+ if (shouldExpire(info, System.currentTimeMillis())) {
+ logInfo(
+ log"Found session ${MDC(SESSION_HOLD_INFO, info)} that expired " +
+ log"and will be closed.")
+ removeSessionHolder(info.key)
+ try {
+ shutdownSessionHolder(sessionHolder)
+ } catch {
+ case NonFatal(ex) => logWarning("Unexpected exception closing session", ex)
}
}
- // do shutdown and cleanup outside of lock.
- try removedSession.foreach(shutdownSessionHolder(_))
- catch {
- case NonFatal(ex) => logWarning("Unexpected exception closing session", ex)
- }
}
logInfo("Finished periodic run of SparkConnectSessionManager maintenance.")
}
@@ -309,7 +296,7 @@ class SparkConnectSessionManager extends Logging {
/**
* Used for testing
*/
- private[connect] def invalidateAllSessions(): Unit = sessionsLock.synchronized {
+ private[connect] def invalidateAllSessions(): Unit = {
periodicMaintenance(defaultInactiveTimeoutMs = 0L, ignoreCustomTimeout = true)
assert(sessionStore.isEmpty)
closedSessionsCache.invalidateAll()
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala
index 03719ddd87419..8241672d5107b 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.connect.service
import java.util.concurrent.Executors
import java.util.concurrent.ScheduledExecutorService
import java.util.concurrent.TimeUnit
+import java.util.concurrent.atomic.AtomicReference
import javax.annotation.concurrent.GuardedBy
import scala.collection.mutable
@@ -185,10 +186,10 @@ private[connect] class SparkConnectStreamingQueryCache(
// Visible for testing.
private[service] def shutdown(): Unit = queryCacheLock.synchronized {
- scheduledExecutor.foreach { executor =>
+ val executor = scheduledExecutor.getAndSet(null)
+ if (executor != null) {
ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES))
}
- scheduledExecutor = None
}
@GuardedBy("queryCacheLock")
@@ -199,19 +200,19 @@ private[connect] class SparkConnectStreamingQueryCache(
private val taggedQueries = new mutable.HashMap[String, mutable.ArrayBuffer[QueryCacheKey]]
private val taggedQueriesLock = new Object
- @GuardedBy("queryCacheLock")
- private var scheduledExecutor: Option[ScheduledExecutorService] = None
+ private var scheduledExecutor: AtomicReference[ScheduledExecutorService] =
+ new AtomicReference[ScheduledExecutorService]()
/** Schedules periodic checks if it is not already scheduled */
- private def schedulePeriodicChecks(): Unit = queryCacheLock.synchronized {
- scheduledExecutor match {
- case Some(_) => // Already running.
- case None =>
+ private def schedulePeriodicChecks(): Unit = {
+ var executor = scheduledExecutor.getAcquire()
+ if (executor == null) {
+ executor = Executors.newSingleThreadScheduledExecutor()
+ if (scheduledExecutor.compareAndExchangeRelease(null, executor) == null) {
logInfo(
log"Starting thread for polling streaming sessions " +
log"every ${MDC(DURATION, sessionPollingPeriod.toMillis)}")
- scheduledExecutor = Some(Executors.newSingleThreadScheduledExecutor())
- scheduledExecutor.get.scheduleAtFixedRate(
+ executor.scheduleAtFixedRate(
() => {
try periodicMaintenance()
catch {
@@ -221,6 +222,7 @@ private[connect] class SparkConnectStreamingQueryCache(
sessionPollingPeriod.toMillis,
sessionPollingPeriod.toMillis,
TimeUnit.MILLISECONDS)
+ }
}
}
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
index 355048cf30363..f1636ed1ef092 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
@@ -205,7 +205,9 @@ private[connect] object ErrorUtils extends Logging {
case _ =>
}
- if (sessionHolderOpt.exists(_.session.conf.get(Connect.CONNECT_ENRICH_ERROR_ENABLED))) {
+ val enrichErrorEnabled = sessionHolderOpt.exists(
+ _.session.sessionState.conf.getConf(Connect.CONNECT_ENRICH_ERROR_ENABLED))
+ if (enrichErrorEnabled) {
// Generate a new unique key for this exception.
val errorId = UUID.randomUUID().toString
@@ -216,9 +218,10 @@ private[connect] object ErrorUtils extends Logging {
}
lazy val stackTrace = Option(ExceptionUtils.getStackTrace(st))
+ val stackTraceEnabled = sessionHolderOpt.exists(
+ _.session.sessionState.conf.getConf(SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED))
val withStackTrace =
- if (sessionHolderOpt.exists(
- _.session.conf.get(SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED) && stackTrace.nonEmpty)) {
+ if (stackTraceEnabled && stackTrace.nonEmpty) {
val maxSize = Math.min(
SparkEnv.get.conf.get(Connect.CONNECT_JVM_STACK_TRACE_MAX_SIZE),
maxMetadataSize)
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala
index 4f6daded402a9..29ad97ad9fbe8 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala
@@ -179,6 +179,11 @@ class ProtoToParsedPlanTestSuite
logError(log"Skipping ${MDC(PATH, fileName)}")
return
}
+ // TODO: enable below by SPARK-49487
+ if (fileName.contains("transpose")) {
+ logError(log"Skipping ${MDC(PATH, fileName)} because of SPARK-49487")
+ return
+ }
val name = fileName.stripSuffix(".proto.bin")
test(name) {
val relation = readRelation(file)
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListenerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListenerSuite.scala
index 7c1b9362425d9..df5df23e77505 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListenerSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ConnectProgressExecutionListenerSuite.scala
@@ -79,11 +79,16 @@ class ConnectProgressExecutionListenerSuite extends SparkFunSuite with MockitoSu
}
}
- test("taskDone") {
+ def testTaskDone(metricsPopulated: Boolean): Unit = {
val listener = new ConnectProgressExecutionListener
listener.registerJobTag(testTag)
listener.onJobStart(testJobStart)
+ val metricsOrNull = if (metricsPopulated) {
+ testStage1Task1Metrics
+ } else {
+ null
+ }
// Finish the tasks
val taskEnd = SparkListenerTaskEnd(
1,
@@ -92,7 +97,7 @@ class ConnectProgressExecutionListenerSuite extends SparkFunSuite with MockitoSu
Success,
testStage1Task1,
testStage1Task1ExecutorMetrics,
- testStage1Task1Metrics)
+ metricsOrNull)
val t = listener.trackedTags(testTag)
var yielded = false
@@ -117,7 +122,11 @@ class ConnectProgressExecutionListenerSuite extends SparkFunSuite with MockitoSu
assert(stages.map(_.numTasks).sum == 2)
assert(stages.map(_.completedTasks).sum == 1)
assert(stages.size == 2)
- assert(stages.map(_.inputBytesRead).sum == 500)
+ if (metricsPopulated) {
+ assert(stages.map(_.inputBytesRead).sum == 500)
+ } else {
+ assert(stages.map(_.inputBytesRead).sum == 0)
+ }
assert(
stages
.map(_.completed match {
@@ -140,7 +149,11 @@ class ConnectProgressExecutionListenerSuite extends SparkFunSuite with MockitoSu
assert(stages.map(_.numTasks).sum == 2)
assert(stages.map(_.completedTasks).sum == 1)
assert(stages.size == 2)
- assert(stages.map(_.inputBytesRead).sum == 500)
+ if (metricsPopulated) {
+ assert(stages.map(_.inputBytesRead).sum == 500)
+ } else {
+ assert(stages.map(_.inputBytesRead).sum == 0)
+ }
assert(
stages
.map(_.completed match {
@@ -153,4 +166,12 @@ class ConnectProgressExecutionListenerSuite extends SparkFunSuite with MockitoSu
assert(yielded, "Must updated with results")
}
+ test("taskDone - populated metrics") {
+ testTaskDone(metricsPopulated = true)
+ }
+
+ test("taskDone - null metrics") {
+ testTaskDone(metricsPopulated = false)
+ }
+
}
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala
index 25e6cc48a1998..2606284c25bd5 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala
@@ -429,4 +429,27 @@ class ReattachableExecuteSuite extends SparkConnectServerTest {
val abandonedExecutions = manager.listAbandonedExecutions
assert(abandonedExecutions.forall(_.operationId != dummyOpId))
}
+
+ test("SPARK-49492: reattach must not succeed on an inactive execution holder") {
+ withRawBlockingStub { stub =>
+ val operationId = UUID.randomUUID().toString
+
+ // supply an invalid plan so that the execute plan handler raises an error
+ val iter = stub.executePlan(
+ buildExecutePlanRequest(proto.Plan.newBuilder().build(), operationId = operationId))
+
+ // expect that the execution fails before spawning an execute thread
+ val ee = intercept[StatusRuntimeException] {
+ iter.next()
+ }
+ assert(ee.getMessage.contains("INTERNAL"))
+
+ // reattach must fail
+ val reattach = stub.reattachExecute(buildReattachExecuteRequest(operationId, None))
+ val re = intercept[StatusRuntimeException] {
+ reattach.hasNext()
+ }
+ assert(re.getMessage.contains("INVALID_HANDLE.OPERATION_NOT_FOUND"))
+ }
+ }
}
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 03bf5a4c10dbc..cad7fe6370827 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -1042,7 +1042,7 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
analyzePlan(
transform(connectTestRelation.observe("my_metric", "id".protoAttr.cast("string"))))
},
- errorClass = "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_ATTRIBUTE",
+ condition = "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_ATTRIBUTE",
parameters = Map("expr" -> "\"CAST(id AS STRING) AS id\""))
val connectPlan2 =
@@ -1073,7 +1073,7 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
transform(
connectTestRelation.observe(Observation("my_metric"), "id".protoAttr.cast("string"))))
},
- errorClass = "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_ATTRIBUTE",
+ condition = "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_ATTRIBUTE",
parameters = Map("expr" -> "\"CAST(id AS STRING) AS id\""))
}
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
index 579fdb47aef3c..62146f19328a8 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -871,10 +871,16 @@ class SparkConnectServiceSuite
class VerifyEvents(val sparkContext: SparkContext) {
val listener: MockSparkListener = new MockSparkListener()
val listenerBus = sparkContext.listenerBus
+ val EVENT_WAIT_TIMEOUT = timeout(10.seconds)
val LISTENER_BUS_TIMEOUT = 30000
def executeHolder: ExecuteHolder = {
- assert(listener.executeHolder.isDefined)
- listener.executeHolder.get
+ // An ExecuteHolder shall be set eventually through MockSparkListener
+ Eventually.eventually(EVENT_WAIT_TIMEOUT) {
+ assert(
+ listener.executeHolder.isDefined,
+ s"No events have been posted in $EVENT_WAIT_TIMEOUT")
+ listener.executeHolder.get
+ }
}
def onNext(v: proto.ExecutePlanResponse): Unit = {
if (v.hasSchema) {
@@ -891,8 +897,10 @@ class SparkConnectServiceSuite
def onCompleted(producedRowCount: Option[Long] = None): Unit = {
assert(executeHolder.eventsManager.getProducedRowCount == producedRowCount)
// The eventsManager is closed asynchronously
- Eventually.eventually(timeout(1.seconds)) {
- assert(executeHolder.eventsManager.status == ExecuteStatus.Closed)
+ Eventually.eventually(EVENT_WAIT_TIMEOUT) {
+ assert(
+ executeHolder.eventsManager.status == ExecuteStatus.Closed,
+ s"Execution has not been completed in $EVENT_WAIT_TIMEOUT")
}
}
def onCanceled(): Unit = {
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala
index 512cdad62b921..2b768875c6e20 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala
@@ -217,7 +217,7 @@ class SparkConnectPluginRegistrySuite extends SharedSparkSession with SparkConne
exception = intercept[SparkException] {
SparkConnectPluginRegistry.loadRelationPlugins()
},
- errorClass = "CONNECT.PLUGIN_CTOR_MISSING",
+ condition = "CONNECT.PLUGIN_CTOR_MISSING",
parameters = Map("cls" -> "org.apache.spark.sql.connect.plugin.DummyPluginNoTrivialCtor"))
}
@@ -228,7 +228,7 @@ class SparkConnectPluginRegistrySuite extends SharedSparkSession with SparkConne
exception = intercept[SparkException] {
SparkConnectPluginRegistry.loadRelationPlugins()
},
- errorClass = "CONNECT.PLUGIN_RUNTIME_ERROR",
+ condition = "CONNECT.PLUGIN_RUNTIME_ERROR",
parameters = Map("msg" -> "Bad Plugin Error"))
}
}
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala
index dbe8420eab03d..a9843e261fff8 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala
@@ -374,7 +374,8 @@ class ExecuteEventsManagerSuite
.setClientType(DEFAULT_CLIENT_TYPE)
.build()
- val executeHolder = new ExecuteHolder(executePlanRequest, sessionHolder)
+ val executeKey = ExecuteKey(executePlanRequest, sessionHolder)
+ val executeHolder = new ExecuteHolder(executeKey, executePlanRequest, sessionHolder)
val eventsManager = ExecuteEventsManager(executeHolder, DEFAULT_CLOCK)
eventsManager.status_(executeStatus)
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/InterceptorRegistrySuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/InterceptorRegistrySuite.scala
index 8f76d58a31476..f125cb2d5c6c0 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/InterceptorRegistrySuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/InterceptorRegistrySuite.scala
@@ -118,7 +118,7 @@ class InterceptorRegistrySuite extends SharedSparkSession {
exception = intercept[SparkException] {
SparkConnectInterceptorRegistry.chainInterceptors(sb)
},
- errorClass = "CONNECT.INTERCEPTOR_CTOR_MISSING",
+ condition = "CONNECT.INTERCEPTOR_CTOR_MISSING",
parameters =
Map("cls" -> "org.apache.spark.sql.connect.service.TestingInterceptorNoTrivialCtor"))
}
@@ -132,7 +132,7 @@ class InterceptorRegistrySuite extends SharedSparkSession {
exception = intercept[SparkException] {
SparkConnectInterceptorRegistry.createConfiguredInterceptors()
},
- errorClass = "CONNECT.INTERCEPTOR_CTOR_MISSING",
+ condition = "CONNECT.INTERCEPTOR_CTOR_MISSING",
parameters =
Map("cls" -> "org.apache.spark.sql.connect.service.TestingInterceptorNoTrivialCtor"))
}
@@ -144,7 +144,7 @@ class InterceptorRegistrySuite extends SharedSparkSession {
exception = intercept[SparkException] {
SparkConnectInterceptorRegistry.createConfiguredInterceptors()
},
- errorClass = "CONNECT.INTERCEPTOR_RUNTIME_ERROR",
+ condition = "CONNECT.INTERCEPTOR_RUNTIME_ERROR",
parameters = Map("msg" -> "Bad Error"))
}
}
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
index beebe5d2e2dc1..ed2f60afb0096 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
@@ -399,7 +399,7 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
test("Test session plan cache - disabled") {
val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
// Disable plan cache of the session
- sessionHolder.session.conf.set(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, false)
+ sessionHolder.session.conf.set(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED.key, false)
val planner = new SparkConnectPlanner(sessionHolder)
val query = buildRelation("select 1")
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/StateMessage.proto b/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/StateMessage.proto
index 113bb8a8b00fb..1ff90f27e173a 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/StateMessage.proto
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/StateMessage.proto
@@ -59,6 +59,7 @@ message ImplicitGroupingKeyRequest {
message StateCallCommand {
string stateName = 1;
string schema = 2;
+ TTLConfig ttl = 3;
}
message ValueStateCall {
@@ -101,3 +102,7 @@ enum HandleState {
message SetHandleState {
HandleState state = 1;
}
+
+message TTLConfig {
+ int32 durationMs = 1;
+}
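StateCallCommand gains an optional TTL, carried by the new TTLConfig message. Assuming the standard protobuf-Java codegen for these messages (nested under the StateMessage outer class, as in the generated file below), a registration request with a TTL would be assembled roughly like this; the state name and schema are placeholders:

```scala
import org.apache.spark.sql.execution.streaming.state.StateMessage

// A one-minute TTL for a state variable.
val ttl = StateMessage.TTLConfig.newBuilder()
  .setDurationMs(60 * 1000)
  .build()

val command = StateMessage.StateCallCommand.newBuilder()
  .setStateName("countState")   // placeholder state variable name
  .setSchema("value int")       // placeholder schema string
  .setTtl(ttl)
  .build()
```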
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/state/StateMessage.java b/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/state/StateMessage.java
index 542bcd8a68abf..4fbb20be05b7b 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/state/StateMessage.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/state/StateMessage.java
@@ -15,9 +15,7 @@
* limitations under the License.
*/
// Generated by the protocol buffer compiler. DO NOT EDIT!
-// NO CHECKED-IN PROTOBUF GENCODE
// source: StateMessage.proto
-// Protobuf Java Version: 4.27.1
package org.apache.spark.sql.execution.streaming.state;
@@ -213,7 +211,7 @@ public interface StateRequestOrBuilder extends
*/
org.apache.spark.sql.execution.streaming.state.StateMessage.ImplicitGroupingKeyRequestOrBuilder getImplicitGroupingKeyRequestOrBuilder();
- org.apache.spark.sql.execution.streaming.state.StateMessage.StateRequest.MethodCase getMethodCase();
+ public org.apache.spark.sql.execution.streaming.state.StateMessage.StateRequest.MethodCase getMethodCase();
}
/**
* Protobuf type {@code org.apache.spark.sql.execution.streaming.state.StateRequest}
@@ -230,6 +228,18 @@ private StateRequest(com.google.protobuf.GeneratedMessageV3.Builder<?> builder)
private StateRequest() {
}
+ @java.lang.Override
+ @SuppressWarnings({"unused"})
+ protected java.lang.Object newInstance(
+ UnusedPrivateParameter unused) {
+ return new StateRequest();
+ }
+
+ @java.lang.Override
+ public final com.google.protobuf.UnknownFieldSet
+ getUnknownFields() {
+ return this.unknownFields;
+ }
public static final com.google.protobuf.Descriptors.Descriptor
getDescriptor() {
return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_StateRequest_descriptor;
@@ -244,7 +254,6 @@ private StateRequest() {
}
private int methodCase_ = 0;
- @SuppressWarnings("serial")
private java.lang.Object method_;
public enum MethodCase
implements com.google.protobuf.Internal.EnumLite,
@@ -288,7 +297,7 @@ public int getNumber() {
}
public static final int VERSION_FIELD_NUMBER = 1;
- private int version_ = 0;
+ private int version_;
/**
* int32 version = 1;
* @return The version.
@@ -554,13 +563,11 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.StateR
return com.google.protobuf.GeneratedMessageV3
.parseWithIOException(PARSER, input, extensionRegistry);
}
-
public static org.apache.spark.sql.execution.streaming.state.StateMessage.StateRequest parseDelimitedFrom(java.io.InputStream input)
throws java.io.IOException {
return com.google.protobuf.GeneratedMessageV3
.parseDelimitedWithIOException(PARSER, input);
}
-
public static org.apache.spark.sql.execution.streaming.state.StateMessage.StateRequest parseDelimitedFrom(
java.io.InputStream input,
com.google.protobuf.ExtensionRegistryLite extensionRegistry)
@@ -635,8 +642,8 @@ private Builder(
@java.lang.Override
public Builder clear() {
super.clear();
- bitField0_ = 0;
version_ = 0;
+
if (statefulProcessorCallBuilder_ != null) {
statefulProcessorCallBuilder_.clear();
}
@@ -674,36 +681,65 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.StateRequest
@java.lang.Override
public org.apache.spark.sql.execution.streaming.state.StateMessage.StateRequest buildPartial() {
org.apache.spark.sql.execution.streaming.state.StateMessage.StateRequest result = new org.apache.spark.sql.execution.streaming.state.StateMessage.StateRequest(this);
- if (bitField0_ != 0) { buildPartial0(result); }
- buildPartialOneofs(result);
+ result.version_ = version_;
+ if (methodCase_ == 2) {
+ if (statefulProcessorCallBuilder_ == null) {
+ result.method_ = method_;
+ } else {
+ result.method_ = statefulProcessorCallBuilder_.build();
+ }
+ }
+ if (methodCase_ == 3) {
+ if (stateVariableRequestBuilder_ == null) {
+ result.method_ = method_;
+ } else {
+ result.method_ = stateVariableRequestBuilder_.build();
+ }
+ }
+ if (methodCase_ == 4) {
+ if (implicitGroupingKeyRequestBuilder_ == null) {
+ result.method_ = method_;
+ } else {
+ result.method_ = implicitGroupingKeyRequestBuilder_.build();
+ }
+ }
+ result.methodCase_ = methodCase_;
onBuilt();
return result;
}
- private void buildPartial0(org.apache.spark.sql.execution.streaming.state.StateMessage.StateRequest result) {
- int from_bitField0_ = bitField0_;
- if (((from_bitField0_ & 0x00000001) != 0)) {
- result.version_ = version_;
- }
+ @java.lang.Override
+ public Builder clone() {
+ return super.clone();
}
-
- private void buildPartialOneofs(org.apache.spark.sql.execution.streaming.state.StateMessage.StateRequest result) {
- result.methodCase_ = methodCase_;
- result.method_ = this.method_;
- if (methodCase_ == 2 &&
- statefulProcessorCallBuilder_ != null) {
- result.method_ = statefulProcessorCallBuilder_.build();
- }
- if (methodCase_ == 3 &&
- stateVariableRequestBuilder_ != null) {
- result.method_ = stateVariableRequestBuilder_.build();
- }
- if (methodCase_ == 4 &&
- implicitGroupingKeyRequestBuilder_ != null) {
- result.method_ = implicitGroupingKeyRequestBuilder_.build();
- }
+ @java.lang.Override
+ public Builder setField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ java.lang.Object value) {
+ return super.setField(field, value);
+ }
+ @java.lang.Override
+ public Builder clearField(
+ com.google.protobuf.Descriptors.FieldDescriptor field) {
+ return super.clearField(field);
+ }
+ @java.lang.Override
+ public Builder clearOneof(
+ com.google.protobuf.Descriptors.OneofDescriptor oneof) {
+ return super.clearOneof(oneof);
+ }
+ @java.lang.Override
+ public Builder setRepeatedField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ int index, java.lang.Object value) {
+ return super.setRepeatedField(field, index, value);
+ }
+ @java.lang.Override
+ public Builder addRepeatedField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ java.lang.Object value) {
+ return super.addRepeatedField(field, value);
}
-
@java.lang.Override
public Builder mergeFrom(com.google.protobuf.Message other) {
if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.StateRequest) {
@@ -764,7 +800,7 @@ public Builder mergeFrom(
break;
case 8: {
version_ = input.readInt32();
- bitField0_ |= 0x00000001;
+
break;
} // case 8
case 18: {
@@ -818,7 +854,6 @@ public Builder clearMethod() {
return this;
}
- private int bitField0_;
private int version_ ;
/**
@@ -835,9 +870,8 @@ public int getVersion() {
* @return This builder for chaining.
*/
public Builder setVersion(int value) {
-
+
version_ = value;
- bitField0_ |= 0x00000001;
onChanged();
return this;
}
@@ -846,7 +880,7 @@ public Builder setVersion(int value) {
* @return This builder for chaining.
*/
public Builder clearVersion() {
- bitField0_ = (bitField0_ & ~0x00000001);
+
version_ = 0;
onChanged();
return this;
@@ -990,7 +1024,7 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.StatefulProce
method_ = null;
}
methodCase_ = 2;
- onChanged();
+ onChanged();;
return statefulProcessorCallBuilder_;
}
@@ -1132,7 +1166,7 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.StateVariable
method_ = null;
}
methodCase_ = 3;
- onChanged();
+ onChanged();;
return stateVariableRequestBuilder_;
}
@@ -1274,9 +1308,21 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.ImplicitGroup
method_ = null;
}
methodCase_ = 4;
- onChanged();
+ onChanged();;
return implicitGroupingKeyRequestBuilder_;
}
+ @java.lang.Override
+ public final Builder setUnknownFields(
+ final com.google.protobuf.UnknownFieldSet unknownFields) {
+ return super.setUnknownFields(unknownFields);
+ }
+
+ @java.lang.Override
+ public final Builder mergeUnknownFields(
+ final com.google.protobuf.UnknownFieldSet unknownFields) {
+ return super.mergeUnknownFields(unknownFields);
+ }
+
// @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.StateRequest)
}
@@ -1374,6 +1420,18 @@ private StateResponse() {
value_ = com.google.protobuf.ByteString.EMPTY;
}
+ @java.lang.Override
+ @SuppressWarnings({"unused"})
+ protected java.lang.Object newInstance(
+ UnusedPrivateParameter unused) {
+ return new StateResponse();
+ }
+
+ @java.lang.Override
+ public final com.google.protobuf.UnknownFieldSet
+ getUnknownFields() {
+ return this.unknownFields;
+ }
public static final com.google.protobuf.Descriptors.Descriptor
getDescriptor() {
return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_StateResponse_descriptor;
@@ -1388,7 +1446,7 @@ private StateResponse() {
}
public static final int STATUSCODE_FIELD_NUMBER = 1;
- private int statusCode_ = 0;
+ private int statusCode_;
/**
* <code>int32 statusCode = 1;</code>
* @return The statusCode.
@@ -1399,8 +1457,7 @@ public int getStatusCode() {
}
public static final int ERRORMESSAGE_FIELD_NUMBER = 2;
- @SuppressWarnings("serial")
- private volatile java.lang.Object errorMessage_ = "";
+ private volatile java.lang.Object errorMessage_;
/**
* <code>string errorMessage = 2;</code>
* @return The errorMessage.
@@ -1438,7 +1495,7 @@ public java.lang.String getErrorMessage() {
}
public static final int VALUE_FIELD_NUMBER = 3;
- private com.google.protobuf.ByteString value_ = com.google.protobuf.ByteString.EMPTY;
+ private com.google.protobuf.ByteString value_;
/**
* <code>bytes value = 3;</code>
* @return The value.
@@ -1578,13 +1635,11 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.StateR
return com.google.protobuf.GeneratedMessageV3
.parseWithIOException(PARSER, input, extensionRegistry);
}
-
public static org.apache.spark.sql.execution.streaming.state.StateMessage.StateResponse parseDelimitedFrom(java.io.InputStream input)
throws java.io.IOException {
return com.google.protobuf.GeneratedMessageV3
.parseDelimitedWithIOException(PARSER, input);
}
-
public static org.apache.spark.sql.execution.streaming.state.StateMessage.StateResponse parseDelimitedFrom(
java.io.InputStream input,
com.google.protobuf.ExtensionRegistryLite extensionRegistry)
@@ -1659,10 +1714,12 @@ private Builder(
@java.lang.Override
public Builder clear() {
super.clear();
- bitField0_ = 0;
statusCode_ = 0;
+
errorMessage_ = "";
+
value_ = com.google.protobuf.ByteString.EMPTY;
+
return this;
}
@@ -1689,24 +1746,45 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.StateResponse
@java.lang.Override
public org.apache.spark.sql.execution.streaming.state.StateMessage.StateResponse buildPartial() {
org.apache.spark.sql.execution.streaming.state.StateMessage.StateResponse result = new org.apache.spark.sql.execution.streaming.state.StateMessage.StateResponse(this);
- if (bitField0_ != 0) { buildPartial0(result); }
+ result.statusCode_ = statusCode_;
+ result.errorMessage_ = errorMessage_;
+ result.value_ = value_;
onBuilt();
return result;
}
- private void buildPartial0(org.apache.spark.sql.execution.streaming.state.StateMessage.StateResponse result) {
- int from_bitField0_ = bitField0_;
- if (((from_bitField0_ & 0x00000001) != 0)) {
- result.statusCode_ = statusCode_;
- }
- if (((from_bitField0_ & 0x00000002) != 0)) {
- result.errorMessage_ = errorMessage_;
- }
- if (((from_bitField0_ & 0x00000004) != 0)) {
- result.value_ = value_;
- }
+ @java.lang.Override
+ public Builder clone() {
+ return super.clone();
+ }
+ @java.lang.Override
+ public Builder setField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ java.lang.Object value) {
+ return super.setField(field, value);
+ }
+ @java.lang.Override
+ public Builder clearField(
+ com.google.protobuf.Descriptors.FieldDescriptor field) {
+ return super.clearField(field);
+ }
+ @java.lang.Override
+ public Builder clearOneof(
+ com.google.protobuf.Descriptors.OneofDescriptor oneof) {
+ return super.clearOneof(oneof);
+ }
+ @java.lang.Override
+ public Builder setRepeatedField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ int index, java.lang.Object value) {
+ return super.setRepeatedField(field, index, value);
+ }
+ @java.lang.Override
+ public Builder addRepeatedField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ java.lang.Object value) {
+ return super.addRepeatedField(field, value);
}
-
@java.lang.Override
public Builder mergeFrom(com.google.protobuf.Message other) {
if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.StateResponse) {
@@ -1724,7 +1802,6 @@ public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMes
}
if (!other.getErrorMessage().isEmpty()) {
errorMessage_ = other.errorMessage_;
- bitField0_ |= 0x00000002;
onChanged();
}
if (other.getValue() != com.google.protobuf.ByteString.EMPTY) {
@@ -1758,17 +1835,17 @@ public Builder mergeFrom(
break;
case 8: {
statusCode_ = input.readInt32();
- bitField0_ |= 0x00000001;
+
break;
} // case 8
case 18: {
errorMessage_ = input.readStringRequireUtf8();
- bitField0_ |= 0x00000002;
+
break;
} // case 18
case 26: {
value_ = input.readBytes();
- bitField0_ |= 0x00000004;
+
break;
} // case 26
default: {
@@ -1786,7 +1863,6 @@ public Builder mergeFrom(
} // finally
return this;
}
- private int bitField0_;
private int statusCode_ ;
/**
@@ -1803,9 +1879,8 @@ public int getStatusCode() {
* @return This builder for chaining.
*/
public Builder setStatusCode(int value) {
-
+
statusCode_ = value;
- bitField0_ |= 0x00000001;
onChanged();
return this;
}
@@ -1814,7 +1889,7 @@ public Builder setStatusCode(int value) {
* @return This builder for chaining.
*/
public Builder clearStatusCode() {
- bitField0_ = (bitField0_ & ~0x00000001);
+
statusCode_ = 0;
onChanged();
return this;
@@ -1861,9 +1936,11 @@ public java.lang.String getErrorMessage() {
*/
public Builder setErrorMessage(
java.lang.String value) {
- if (value == null) { throw new NullPointerException(); }
+ if (value == null) {
+ throw new NullPointerException();
+ }
+
errorMessage_ = value;
- bitField0_ |= 0x00000002;
onChanged();
return this;
}
@@ -1872,8 +1949,8 @@ public Builder setErrorMessage(
* @return This builder for chaining.
*/
public Builder clearErrorMessage() {
+
errorMessage_ = getDefaultInstance().getErrorMessage();
- bitField0_ = (bitField0_ & ~0x00000002);
onChanged();
return this;
}
@@ -1884,10 +1961,12 @@ public Builder clearErrorMessage() {
*/
public Builder setErrorMessageBytes(
com.google.protobuf.ByteString value) {
- if (value == null) { throw new NullPointerException(); }
- checkByteStringIsUtf8(value);
+ if (value == null) {
+ throw new NullPointerException();
+ }
+ checkByteStringIsUtf8(value);
+
errorMessage_ = value;
- bitField0_ |= 0x00000002;
onChanged();
return this;
}
@@ -1907,9 +1986,11 @@ public com.google.protobuf.ByteString getValue() {
* @return This builder for chaining.
*/
public Builder setValue(com.google.protobuf.ByteString value) {
- if (value == null) { throw new NullPointerException(); }
+ if (value == null) {
+ throw new NullPointerException();
+ }
+
value_ = value;
- bitField0_ |= 0x00000004;
onChanged();
return this;
}
@@ -1918,11 +1999,23 @@ public Builder setValue(com.google.protobuf.ByteString value) {
* @return This builder for chaining.
*/
public Builder clearValue() {
- bitField0_ = (bitField0_ & ~0x00000004);
+
value_ = getDefaultInstance().getValue();
onChanged();
return this;
}
+ @java.lang.Override
+ public final Builder setUnknownFields(
+ final com.google.protobuf.UnknownFieldSet unknownFields) {
+ return super.setUnknownFields(unknownFields);
+ }
+
+ @java.lang.Override
+ public final Builder mergeUnknownFields(
+ final com.google.protobuf.UnknownFieldSet unknownFields) {
+ return super.mergeUnknownFields(unknownFields);
+ }
+
// @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.StateResponse)
}
@@ -2039,7 +2132,7 @@ public interface StatefulProcessorCallOrBuilder extends
*/
org.apache.spark.sql.execution.streaming.state.StateMessage.StateCallCommandOrBuilder getGetMapStateOrBuilder();
- org.apache.spark.sql.execution.streaming.state.StateMessage.StatefulProcessorCall.MethodCase getMethodCase();
+ public org.apache.spark.sql.execution.streaming.state.StateMessage.StatefulProcessorCall.MethodCase getMethodCase();
}
/**
* Protobuf type {@code org.apache.spark.sql.execution.streaming.state.StatefulProcessorCall}
@@ -2056,6 +2149,18 @@ private StatefulProcessorCall(com.google.protobuf.GeneratedMessageV3.Builder<?>
private StatefulProcessorCall() {
}
+ @java.lang.Override
+ @SuppressWarnings({"unused"})
+ protected java.lang.Object newInstance(
+ UnusedPrivateParameter unused) {
+ return new StatefulProcessorCall();
+ }
+
+ @java.lang.Override
+ public final com.google.protobuf.UnknownFieldSet
+ getUnknownFields() {
+ return this.unknownFields;
+ }
public static final com.google.protobuf.Descriptors.Descriptor
getDescriptor() {
return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_StatefulProcessorCall_descriptor;
@@ -2070,7 +2175,6 @@ private StatefulProcessorCall() {
}
private int methodCase_ = 0;
- @SuppressWarnings("serial")
private java.lang.Object method_;
public enum MethodCase
implements com.google.protobuf.Internal.EnumLite,
@@ -2406,13 +2510,11 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.Statef
return com.google.protobuf.GeneratedMessageV3
.parseWithIOException(PARSER, input, extensionRegistry);
}
-
public static org.apache.spark.sql.execution.streaming.state.StateMessage.StatefulProcessorCall parseDelimitedFrom(java.io.InputStream input)
throws java.io.IOException {
return com.google.protobuf.GeneratedMessageV3
.parseDelimitedWithIOException(PARSER, input);
}
-
public static org.apache.spark.sql.execution.streaming.state.StateMessage.StatefulProcessorCall parseDelimitedFrom(
java.io.InputStream input,
com.google.protobuf.ExtensionRegistryLite extensionRegistry)
@@ -2487,7 +2589,6 @@ private Builder(
@java.lang.Override
public Builder clear() {
super.clear();
- bitField0_ = 0;
if (setHandleStateBuilder_ != null) {
setHandleStateBuilder_.clear();
}
@@ -2528,37 +2629,71 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.StatefulProce
@java.lang.Override
public org.apache.spark.sql.execution.streaming.state.StateMessage.StatefulProcessorCall buildPartial() {
org.apache.spark.sql.execution.streaming.state.StateMessage.StatefulProcessorCall result = new org.apache.spark.sql.execution.streaming.state.StateMessage.StatefulProcessorCall(this);
- if (bitField0_ != 0) { buildPartial0(result); }
- buildPartialOneofs(result);
- onBuilt();
- return result;
- }
-
- private void buildPartial0(org.apache.spark.sql.execution.streaming.state.StateMessage.StatefulProcessorCall result) {
- int from_bitField0_ = bitField0_;
- }
-
- private void buildPartialOneofs(org.apache.spark.sql.execution.streaming.state.StateMessage.StatefulProcessorCall result) {
- result.methodCase_ = methodCase_;
- result.method_ = this.method_;
- if (methodCase_ == 1 &&
- setHandleStateBuilder_ != null) {
- result.method_ = setHandleStateBuilder_.build();
+ if (methodCase_ == 1) {
+ if (setHandleStateBuilder_ == null) {
+ result.method_ = method_;
+ } else {
+ result.method_ = setHandleStateBuilder_.build();
+ }
}
- if (methodCase_ == 2 &&
- getValueStateBuilder_ != null) {
- result.method_ = getValueStateBuilder_.build();
+ if (methodCase_ == 2) {
+ if (getValueStateBuilder_ == null) {
+ result.method_ = method_;
+ } else {
+ result.method_ = getValueStateBuilder_.build();
+ }
}
- if (methodCase_ == 3 &&
- getListStateBuilder_ != null) {
- result.method_ = getListStateBuilder_.build();
+ if (methodCase_ == 3) {
+ if (getListStateBuilder_ == null) {
+ result.method_ = method_;
+ } else {
+ result.method_ = getListStateBuilder_.build();
+ }
}
- if (methodCase_ == 4 &&
- getMapStateBuilder_ != null) {
- result.method_ = getMapStateBuilder_.build();
+ if (methodCase_ == 4) {
+ if (getMapStateBuilder_ == null) {
+ result.method_ = method_;
+ } else {
+ result.method_ = getMapStateBuilder_.build();
+ }
}
+ result.methodCase_ = methodCase_;
+ onBuilt();
+ return result;
}
+ @java.lang.Override
+ public Builder clone() {
+ return super.clone();
+ }
+ @java.lang.Override
+ public Builder setField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ java.lang.Object value) {
+ return super.setField(field, value);
+ }
+ @java.lang.Override
+ public Builder clearField(
+ com.google.protobuf.Descriptors.FieldDescriptor field) {
+ return super.clearField(field);
+ }
+ @java.lang.Override
+ public Builder clearOneof(
+ com.google.protobuf.Descriptors.OneofDescriptor oneof) {
+ return super.clearOneof(oneof);
+ }
+ @java.lang.Override
+ public Builder setRepeatedField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ int index, java.lang.Object value) {
+ return super.setRepeatedField(field, index, value);
+ }
+ @java.lang.Override
+ public Builder addRepeatedField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ java.lang.Object value) {
+ return super.addRepeatedField(field, value);
+ }
@java.lang.Override
public Builder mergeFrom(com.google.protobuf.Message other) {
if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.StatefulProcessorCall) {
@@ -2676,7 +2811,6 @@ public Builder clearMethod() {
return this;
}
- private int bitField0_;
private com.google.protobuf.SingleFieldBuilderV3<
org.apache.spark.sql.execution.streaming.state.StateMessage.SetHandleState, org.apache.spark.sql.execution.streaming.state.StateMessage.SetHandleState.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.SetHandleStateOrBuilder> setHandleStateBuilder_;
@@ -2816,7 +2950,7 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.SetHandleStat
method_ = null;
}
methodCase_ = 1;
- onChanged();
+ onChanged();;
return setHandleStateBuilder_;
}
@@ -2958,7 +3092,7 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.StateCallComm
method_ = null;
}
methodCase_ = 2;
- onChanged();
+ onChanged();;
return getValueStateBuilder_;
}
@@ -3100,7 +3234,7 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.StateCallComm
method_ = null;
}
methodCase_ = 3;
- onChanged();
+ onChanged();;
return getListStateBuilder_;
}
@@ -3242,9 +3376,21 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.StateCallComm
method_ = null;
}
methodCase_ = 4;
- onChanged();
+ onChanged();;
return getMapStateBuilder_;
}
+ @java.lang.Override
+ public final Builder setUnknownFields(
+ final com.google.protobuf.UnknownFieldSet unknownFields) {
+ return super.setUnknownFields(unknownFields);
+ }
+
+ @java.lang.Override
+ public final Builder mergeUnknownFields(
+ final com.google.protobuf.UnknownFieldSet unknownFields) {
+ return super.mergeUnknownFields(unknownFields);
+ }
+
// @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.StatefulProcessorCall)
}
@@ -3316,7 +3462,7 @@ public interface StateVariableRequestOrBuilder extends
*/
org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCallOrBuilder getValueStateCallOrBuilder();
- org.apache.spark.sql.execution.streaming.state.StateMessage.StateVariableRequest.MethodCase getMethodCase();
+ public org.apache.spark.sql.execution.streaming.state.StateMessage.StateVariableRequest.MethodCase getMethodCase();
}
/**
* Protobuf type {@code org.apache.spark.sql.execution.streaming.state.StateVariableRequest}
@@ -3333,6 +3479,18 @@ private StateVariableRequest(com.google.protobuf.GeneratedMessageV3.Builder<?> b
private StateVariableRequest() {
}
+ @java.lang.Override
+ @SuppressWarnings({"unused"})
+ protected java.lang.Object newInstance(
+ UnusedPrivateParameter unused) {
+ return new StateVariableRequest();
+ }
+
+ @java.lang.Override
+ public final com.google.protobuf.UnknownFieldSet
+ getUnknownFields() {
+ return this.unknownFields;
+ }
public static final com.google.protobuf.Descriptors.Descriptor
getDescriptor() {
return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_StateVariableRequest_descriptor;
@@ -3347,7 +3505,6 @@ private StateVariableRequest() {
}
private int methodCase_ = 0;
- @SuppressWarnings("serial")
private java.lang.Object method_;
public enum MethodCase
implements com.google.protobuf.Internal.EnumLite,
@@ -3539,13 +3696,11 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.StateV
return com.google.protobuf.GeneratedMessageV3
.parseWithIOException(PARSER, input, extensionRegistry);
}
-
public static org.apache.spark.sql.execution.streaming.state.StateMessage.StateVariableRequest parseDelimitedFrom(java.io.InputStream input)
throws java.io.IOException {
return com.google.protobuf.GeneratedMessageV3
.parseDelimitedWithIOException(PARSER, input);
}
-
public static org.apache.spark.sql.execution.streaming.state.StateMessage.StateVariableRequest parseDelimitedFrom(
java.io.InputStream input,
com.google.protobuf.ExtensionRegistryLite extensionRegistry)
@@ -3620,7 +3775,6 @@ private Builder(
@java.lang.Override
public Builder clear() {
super.clear();
- bitField0_ = 0;
if (valueStateCallBuilder_ != null) {
valueStateCallBuilder_.clear();
}
@@ -3652,25 +3806,50 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.StateVariable
@java.lang.Override
public org.apache.spark.sql.execution.streaming.state.StateMessage.StateVariableRequest buildPartial() {
org.apache.spark.sql.execution.streaming.state.StateMessage.StateVariableRequest result = new org.apache.spark.sql.execution.streaming.state.StateMessage.StateVariableRequest(this);
- if (bitField0_ != 0) { buildPartial0(result); }
- buildPartialOneofs(result);
+ if (methodCase_ == 1) {
+ if (valueStateCallBuilder_ == null) {
+ result.method_ = method_;
+ } else {
+ result.method_ = valueStateCallBuilder_.build();
+ }
+ }
+ result.methodCase_ = methodCase_;
onBuilt();
return result;
}
- private void buildPartial0(org.apache.spark.sql.execution.streaming.state.StateMessage.StateVariableRequest result) {
- int from_bitField0_ = bitField0_;
+ @java.lang.Override
+ public Builder clone() {
+ return super.clone();
}
-
- private void buildPartialOneofs(org.apache.spark.sql.execution.streaming.state.StateMessage.StateVariableRequest result) {
- result.methodCase_ = methodCase_;
- result.method_ = this.method_;
- if (methodCase_ == 1 &&
- valueStateCallBuilder_ != null) {
- result.method_ = valueStateCallBuilder_.build();
- }
+ @java.lang.Override
+ public Builder setField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ java.lang.Object value) {
+ return super.setField(field, value);
+ }
+ @java.lang.Override
+ public Builder clearField(
+ com.google.protobuf.Descriptors.FieldDescriptor field) {
+ return super.clearField(field);
+ }
+ @java.lang.Override
+ public Builder clearOneof(
+ com.google.protobuf.Descriptors.OneofDescriptor oneof) {
+ return super.clearOneof(oneof);
+ }
+ @java.lang.Override
+ public Builder setRepeatedField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ int index, java.lang.Object value) {
+ return super.setRepeatedField(field, index, value);
+ }
+ @java.lang.Override
+ public Builder addRepeatedField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ java.lang.Object value) {
+ return super.addRepeatedField(field, value);
}
-
@java.lang.Override
public Builder mergeFrom(com.google.protobuf.Message other) {
if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.StateVariableRequest) {
@@ -3755,7 +3934,6 @@ public Builder clearMethod() {
return this;
}
- private int bitField0_;
private com.google.protobuf.SingleFieldBuilderV3<
org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCall, org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCall.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCallOrBuilder> valueStateCallBuilder_;
@@ -3895,9 +4073,21 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCal
method_ = null;
}
methodCase_ = 1;
- onChanged();
+ onChanged();;
return valueStateCallBuilder_;
}
+ @java.lang.Override
+ public final Builder setUnknownFields(
+ final com.google.protobuf.UnknownFieldSet unknownFields) {
+ return super.setUnknownFields(unknownFields);
+ }
+
+ @java.lang.Override
+ public final Builder mergeUnknownFields(
+ final com.google.protobuf.UnknownFieldSet unknownFields) {
+ return super.mergeUnknownFields(unknownFields);
+ }
+
// @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.StateVariableRequest)
}
@@ -3984,7 +4174,7 @@ public interface ImplicitGroupingKeyRequestOrBuilder extends
*/
org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKeyOrBuilder getRemoveImplicitKeyOrBuilder();
- org.apache.spark.sql.execution.streaming.state.StateMessage.ImplicitGroupingKeyRequest.MethodCase getMethodCase();
+ public org.apache.spark.sql.execution.streaming.state.StateMessage.ImplicitGroupingKeyRequest.MethodCase getMethodCase();
}
/**
* Protobuf type {@code org.apache.spark.sql.execution.streaming.state.ImplicitGroupingKeyRequest}
@@ -4001,6 +4191,18 @@ private ImplicitGroupingKeyRequest(com.google.protobuf.GeneratedMessageV3.Builde
private ImplicitGroupingKeyRequest() {
}
+ @java.lang.Override
+ @SuppressWarnings({"unused"})
+ protected java.lang.Object newInstance(
+ UnusedPrivateParameter unused) {
+ return new ImplicitGroupingKeyRequest();
+ }
+
+ @java.lang.Override
+ public final com.google.protobuf.UnknownFieldSet
+ getUnknownFields() {
+ return this.unknownFields;
+ }
public static final com.google.protobuf.Descriptors.Descriptor
getDescriptor() {
return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ImplicitGroupingKeyRequest_descriptor;
@@ -4015,7 +4217,6 @@ private ImplicitGroupingKeyRequest() {
}
private int methodCase_ = 0;
- @SuppressWarnings("serial")
private java.lang.Object method_;
public enum MethodCase
implements com.google.protobuf.Internal.EnumLite,
@@ -4255,13 +4456,11 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.Implic
return com.google.protobuf.GeneratedMessageV3
.parseWithIOException(PARSER, input, extensionRegistry);
}
-
public static org.apache.spark.sql.execution.streaming.state.StateMessage.ImplicitGroupingKeyRequest parseDelimitedFrom(java.io.InputStream input)
throws java.io.IOException {
return com.google.protobuf.GeneratedMessageV3
.parseDelimitedWithIOException(PARSER, input);
}
-
public static org.apache.spark.sql.execution.streaming.state.StateMessage.ImplicitGroupingKeyRequest parseDelimitedFrom(
java.io.InputStream input,
com.google.protobuf.ExtensionRegistryLite extensionRegistry)
@@ -4336,7 +4535,6 @@ private Builder(
@java.lang.Override
public Builder clear() {
super.clear();
- bitField0_ = 0;
if (setImplicitKeyBuilder_ != null) {
setImplicitKeyBuilder_.clear();
}
@@ -4371,29 +4569,57 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.ImplicitGroup
@java.lang.Override
public org.apache.spark.sql.execution.streaming.state.StateMessage.ImplicitGroupingKeyRequest buildPartial() {
org.apache.spark.sql.execution.streaming.state.StateMessage.ImplicitGroupingKeyRequest result = new org.apache.spark.sql.execution.streaming.state.StateMessage.ImplicitGroupingKeyRequest(this);
- if (bitField0_ != 0) { buildPartial0(result); }
- buildPartialOneofs(result);
+ if (methodCase_ == 1) {
+ if (setImplicitKeyBuilder_ == null) {
+ result.method_ = method_;
+ } else {
+ result.method_ = setImplicitKeyBuilder_.build();
+ }
+ }
+ if (methodCase_ == 2) {
+ if (removeImplicitKeyBuilder_ == null) {
+ result.method_ = method_;
+ } else {
+ result.method_ = removeImplicitKeyBuilder_.build();
+ }
+ }
+ result.methodCase_ = methodCase_;
onBuilt();
return result;
}
- private void buildPartial0(org.apache.spark.sql.execution.streaming.state.StateMessage.ImplicitGroupingKeyRequest result) {
- int from_bitField0_ = bitField0_;
+ @java.lang.Override
+ public Builder clone() {
+ return super.clone();
}
-
- private void buildPartialOneofs(org.apache.spark.sql.execution.streaming.state.StateMessage.ImplicitGroupingKeyRequest result) {
- result.methodCase_ = methodCase_;
- result.method_ = this.method_;
- if (methodCase_ == 1 &&
- setImplicitKeyBuilder_ != null) {
- result.method_ = setImplicitKeyBuilder_.build();
- }
- if (methodCase_ == 2 &&
- removeImplicitKeyBuilder_ != null) {
- result.method_ = removeImplicitKeyBuilder_.build();
- }
+ @java.lang.Override
+ public Builder setField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ java.lang.Object value) {
+ return super.setField(field, value);
+ }
+ @java.lang.Override
+ public Builder clearField(
+ com.google.protobuf.Descriptors.FieldDescriptor field) {
+ return super.clearField(field);
+ }
+ @java.lang.Override
+ public Builder clearOneof(
+ com.google.protobuf.Descriptors.OneofDescriptor oneof) {
+ return super.clearOneof(oneof);
+ }
+ @java.lang.Override
+ public Builder setRepeatedField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ int index, java.lang.Object value) {
+ return super.setRepeatedField(field, index, value);
+ }
+ @java.lang.Override
+ public Builder addRepeatedField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ java.lang.Object value) {
+ return super.addRepeatedField(field, value);
}
-
@java.lang.Override
public Builder mergeFrom(com.google.protobuf.Message other) {
if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ImplicitGroupingKeyRequest) {
@@ -4489,7 +4715,6 @@ public Builder clearMethod() {
return this;
}
- private int bitField0_;
private com.google.protobuf.SingleFieldBuilderV3<
org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey, org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKeyOrBuilder> setImplicitKeyBuilder_;
@@ -4629,7 +4854,7 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKe
method_ = null;
}
methodCase_ = 1;
- onChanged();
+ onChanged();;
return setImplicitKeyBuilder_;
}
@@ -4771,9 +4996,21 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplici
method_ = null;
}
methodCase_ = 2;
- onChanged();
+ onChanged();;
return removeImplicitKeyBuilder_;
}
+ @java.lang.Override
+ public final Builder setUnknownFields(
+ final com.google.protobuf.UnknownFieldSet unknownFields) {
+ return super.setUnknownFields(unknownFields);
+ }
+
+ @java.lang.Override
+ public final Builder mergeUnknownFields(
+ final com.google.protobuf.UnknownFieldSet unknownFields) {
+ return super.mergeUnknownFields(unknownFields);
+ }
+
// @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.ImplicitGroupingKeyRequest)
}
@@ -4853,6 +5090,21 @@ public interface StateCallCommandOrBuilder extends
*/
com.google.protobuf.ByteString
getSchemaBytes();
+
+ /**
+ * <code>.org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3;</code>
+ * @return Whether the ttl field is set.
+ */
+ boolean hasTtl();
+ /**
+ * <code>.org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3;</code>
+ * @return The ttl.
+ */
+ org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig getTtl();
+ /**
+ * <code>.org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3;</code>
+ */
+ org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfigOrBuilder getTtlOrBuilder();
}
/**
* Protobuf type {@code org.apache.spark.sql.execution.streaming.state.StateCallCommand}
@@ -4871,6 +5123,18 @@ private StateCallCommand() {
schema_ = "";
}
+ @java.lang.Override
+ @SuppressWarnings({"unused"})
+ protected java.lang.Object newInstance(
+ UnusedPrivateParameter unused) {
+ return new StateCallCommand();
+ }
+
+ @java.lang.Override
+ public final com.google.protobuf.UnknownFieldSet
+ getUnknownFields() {
+ return this.unknownFields;
+ }
public static final com.google.protobuf.Descriptors.Descriptor
getDescriptor() {
return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_StateCallCommand_descriptor;
@@ -4885,8 +5149,7 @@ private StateCallCommand() {
}
public static final int STATENAME_FIELD_NUMBER = 1;
- @SuppressWarnings("serial")
- private volatile java.lang.Object stateName_ = "";
+ private volatile java.lang.Object stateName_;
/**
* <code>string stateName = 1;</code>
* @return The stateName.
@@ -4924,8 +5187,7 @@ public java.lang.String getStateName() {
}
public static final int SCHEMA_FIELD_NUMBER = 2;
- @SuppressWarnings("serial")
- private volatile java.lang.Object schema_ = "";
+ private volatile java.lang.Object schema_;
/**
* <code>string schema = 2;</code>
* @return The schema.
@@ -4962,6 +5224,32 @@ public java.lang.String getSchema() {
}
}
+ public static final int TTL_FIELD_NUMBER = 3;
+ private org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig ttl_;
+ /**
+ * <code>.org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3;</code>
+ * @return Whether the ttl field is set.
+ */
+ @java.lang.Override
+ public boolean hasTtl() {
+ return ttl_ != null;
+ }
+ /**
+ * <code>.org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3;</code>
+ * @return The ttl.
+ */
+ @java.lang.Override
+ public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig getTtl() {
+ return ttl_ == null ? org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.getDefaultInstance() : ttl_;
+ }
+ /**
+ * <code>.org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3;</code>
+ */
+ @java.lang.Override
+ public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfigOrBuilder getTtlOrBuilder() {
+ return getTtl();
+ }
+
private byte memoizedIsInitialized = -1;
@java.lang.Override
public final boolean isInitialized() {
@@ -4982,6 +5270,9 @@ public void writeTo(com.google.protobuf.CodedOutputStream output)
if (!com.google.protobuf.GeneratedMessageV3.isStringEmpty(schema_)) {
com.google.protobuf.GeneratedMessageV3.writeString(output, 2, schema_);
}
+ if (ttl_ != null) {
+ output.writeMessage(3, getTtl());
+ }
getUnknownFields().writeTo(output);
}
@@ -4997,6 +5288,10 @@ public int getSerializedSize() {
if (!com.google.protobuf.GeneratedMessageV3.isStringEmpty(schema_)) {
size += com.google.protobuf.GeneratedMessageV3.computeStringSize(2, schema_);
}
+ if (ttl_ != null) {
+ size += com.google.protobuf.CodedOutputStream
+ .computeMessageSize(3, getTtl());
+ }
size += getUnknownFields().getSerializedSize();
memoizedSize = size;
return size;
@@ -5016,6 +5311,11 @@ public boolean equals(final java.lang.Object obj) {
.equals(other.getStateName())) return false;
if (!getSchema()
.equals(other.getSchema())) return false;
+ if (hasTtl() != other.hasTtl()) return false;
+ if (hasTtl()) {
+ if (!getTtl()
+ .equals(other.getTtl())) return false;
+ }
if (!getUnknownFields().equals(other.getUnknownFields())) return false;
return true;
}
@@ -5031,6 +5331,10 @@ public int hashCode() {
hash = (53 * hash) + getStateName().hashCode();
hash = (37 * hash) + SCHEMA_FIELD_NUMBER;
hash = (53 * hash) + getSchema().hashCode();
+ if (hasTtl()) {
+ hash = (37 * hash) + TTL_FIELD_NUMBER;
+ hash = (53 * hash) + getTtl().hashCode();
+ }
hash = (29 * hash) + getUnknownFields().hashCode();
memoizedHashCode = hash;
return hash;
@@ -5080,13 +5384,11 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.StateC
return com.google.protobuf.GeneratedMessageV3
.parseWithIOException(PARSER, input, extensionRegistry);
}
-
public static org.apache.spark.sql.execution.streaming.state.StateMessage.StateCallCommand parseDelimitedFrom(java.io.InputStream input)
throws java.io.IOException {
return com.google.protobuf.GeneratedMessageV3
.parseDelimitedWithIOException(PARSER, input);
}
-
public static org.apache.spark.sql.execution.streaming.state.StateMessage.StateCallCommand parseDelimitedFrom(
java.io.InputStream input,
com.google.protobuf.ExtensionRegistryLite extensionRegistry)
@@ -5161,9 +5463,16 @@ private Builder(
@java.lang.Override
public Builder clear() {
super.clear();
- bitField0_ = 0;
stateName_ = "";
+
schema_ = "";
+
+ if (ttlBuilder_ == null) {
+ ttl_ = null;
+ } else {
+ ttl_ = null;
+ ttlBuilder_ = null;
+ }
return this;
}
@@ -5190,21 +5499,49 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.StateCallComm
@java.lang.Override
public org.apache.spark.sql.execution.streaming.state.StateMessage.StateCallCommand buildPartial() {
org.apache.spark.sql.execution.streaming.state.StateMessage.StateCallCommand result = new org.apache.spark.sql.execution.streaming.state.StateMessage.StateCallCommand(this);
- if (bitField0_ != 0) { buildPartial0(result); }
+ result.stateName_ = stateName_;
+ result.schema_ = schema_;
+ if (ttlBuilder_ == null) {
+ result.ttl_ = ttl_;
+ } else {
+ result.ttl_ = ttlBuilder_.build();
+ }
onBuilt();
return result;
}
- private void buildPartial0(org.apache.spark.sql.execution.streaming.state.StateMessage.StateCallCommand result) {
- int from_bitField0_ = bitField0_;
- if (((from_bitField0_ & 0x00000001) != 0)) {
- result.stateName_ = stateName_;
- }
- if (((from_bitField0_ & 0x00000002) != 0)) {
- result.schema_ = schema_;
- }
+ @java.lang.Override
+ public Builder clone() {
+ return super.clone();
+ }
+ @java.lang.Override
+ public Builder setField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ java.lang.Object value) {
+ return super.setField(field, value);
+ }
+ @java.lang.Override
+ public Builder clearField(
+ com.google.protobuf.Descriptors.FieldDescriptor field) {
+ return super.clearField(field);
+ }
+ @java.lang.Override
+ public Builder clearOneof(
+ com.google.protobuf.Descriptors.OneofDescriptor oneof) {
+ return super.clearOneof(oneof);
+ }
+ @java.lang.Override
+ public Builder setRepeatedField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ int index, java.lang.Object value) {
+ return super.setRepeatedField(field, index, value);
+ }
+ @java.lang.Override
+ public Builder addRepeatedField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ java.lang.Object value) {
+ return super.addRepeatedField(field, value);
}
-
@java.lang.Override
public Builder mergeFrom(com.google.protobuf.Message other) {
if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.StateCallCommand) {
@@ -5219,14 +5556,15 @@ public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMes
if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.StateCallCommand.getDefaultInstance()) return this;
if (!other.getStateName().isEmpty()) {
stateName_ = other.stateName_;
- bitField0_ |= 0x00000001;
onChanged();
}
if (!other.getSchema().isEmpty()) {
schema_ = other.schema_;
- bitField0_ |= 0x00000002;
onChanged();
}
+ if (other.hasTtl()) {
+ mergeTtl(other.getTtl());
+ }
this.mergeUnknownFields(other.getUnknownFields());
onChanged();
return this;
@@ -5255,14 +5593,21 @@ public Builder mergeFrom(
break;
case 10: {
stateName_ = input.readStringRequireUtf8();
- bitField0_ |= 0x00000001;
+
break;
} // case 10
case 18: {
schema_ = input.readStringRequireUtf8();
- bitField0_ |= 0x00000002;
+
break;
} // case 18
+ case 26: {
+ input.readMessage(
+ getTtlFieldBuilder().getBuilder(),
+ extensionRegistry);
+
+ break;
+ } // case 26
default: {
if (!super.parseUnknownField(input, extensionRegistry, tag)) {
done = true; // was an endgroup tag
@@ -5278,7 +5623,6 @@ public Builder mergeFrom(
} // finally
return this;
}
- private int bitField0_;
private java.lang.Object stateName_ = "";
/**
@@ -5321,9 +5665,11 @@ public java.lang.String getStateName() {
*/
public Builder setStateName(
java.lang.String value) {
- if (value == null) { throw new NullPointerException(); }
+ if (value == null) {
+ throw new NullPointerException();
+ }
+
stateName_ = value;
- bitField0_ |= 0x00000001;
onChanged();
return this;
}
@@ -5332,8 +5678,8 @@ public Builder setStateName(
* @return This builder for chaining.
*/
public Builder clearStateName() {
+
stateName_ = getDefaultInstance().getStateName();
- bitField0_ = (bitField0_ & ~0x00000001);
onChanged();
return this;
}
@@ -5344,10 +5690,12 @@ public Builder clearStateName() {
*/
public Builder setStateNameBytes(
com.google.protobuf.ByteString value) {
- if (value == null) { throw new NullPointerException(); }
- checkByteStringIsUtf8(value);
+ if (value == null) {
+ throw new NullPointerException();
+ }
+ checkByteStringIsUtf8(value);
+
stateName_ = value;
- bitField0_ |= 0x00000001;
onChanged();
return this;
}
@@ -5393,9 +5741,11 @@ public java.lang.String getSchema() {
*/
public Builder setSchema(
java.lang.String value) {
- if (value == null) { throw new NullPointerException(); }
+ if (value == null) {
+ throw new NullPointerException();
+ }
+
schema_ = value;
- bitField0_ |= 0x00000002;
onChanged();
return this;
}
@@ -5404,8 +5754,8 @@ public Builder setSchema(
* @return This builder for chaining.
*/
public Builder clearSchema() {
+
schema_ = getDefaultInstance().getSchema();
- bitField0_ = (bitField0_ & ~0x00000002);
onChanged();
return this;
}
@@ -5416,14 +5766,147 @@ public Builder clearSchema() {
*/
public Builder setSchemaBytes(
com.google.protobuf.ByteString value) {
- if (value == null) { throw new NullPointerException(); }
- checkByteStringIsUtf8(value);
+ if (value == null) {
+ throw new NullPointerException();
+ }
+ checkByteStringIsUtf8(value);
+
schema_ = value;
- bitField0_ |= 0x00000002;
onChanged();
return this;
}
+ private org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig ttl_;
+ private com.google.protobuf.SingleFieldBuilderV3<
+ org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig, org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfigOrBuilder> ttlBuilder_;
+ /**
+ * <code>.org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3;</code>
+ * @return Whether the ttl field is set.
+ */
+ public boolean hasTtl() {
+ return ttlBuilder_ != null || ttl_ != null;
+ }
+ /**
+ * <code>.org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3;</code>
+ * @return The ttl.
+ */
+ public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig getTtl() {
+ if (ttlBuilder_ == null) {
+ return ttl_ == null ? org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.getDefaultInstance() : ttl_;
+ } else {
+ return ttlBuilder_.getMessage();
+ }
+ }
+ /**
+ * <code>.org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3;</code>
+ */
+ public Builder setTtl(org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig value) {
+ if (ttlBuilder_ == null) {
+ if (value == null) {
+ throw new NullPointerException();
+ }
+ ttl_ = value;
+ onChanged();
+ } else {
+ ttlBuilder_.setMessage(value);
+ }
+
+ return this;
+ }
+ /**
+ * <code>.org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3;</code>
+ */
+ public Builder setTtl(
+ org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.Builder builderForValue) {
+ if (ttlBuilder_ == null) {
+ ttl_ = builderForValue.build();
+ onChanged();
+ } else {
+ ttlBuilder_.setMessage(builderForValue.build());
+ }
+
+ return this;
+ }
+ /**
+ * <code>.org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3;</code>
+ */
+ public Builder mergeTtl(org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig value) {
+ if (ttlBuilder_ == null) {
+ if (ttl_ != null) {
+ ttl_ =
+ org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.newBuilder(ttl_).mergeFrom(value).buildPartial();
+ } else {
+ ttl_ = value;
+ }
+ onChanged();
+ } else {
+ ttlBuilder_.mergeFrom(value);
+ }
+
+ return this;
+ }
+ /**
+ * <code>.org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3;</code>
+ */
+ public Builder clearTtl() {
+ if (ttlBuilder_ == null) {
+ ttl_ = null;
+ onChanged();
+ } else {
+ ttl_ = null;
+ ttlBuilder_ = null;
+ }
+
+ return this;
+ }
+ /**
+ * <code>.org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3;</code>
+ */
+ public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.Builder getTtlBuilder() {
+
+ onChanged();
+ return getTtlFieldBuilder().getBuilder();
+ }
+ /**
+ * <code>.org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3;</code>
+ */
+ public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfigOrBuilder getTtlOrBuilder() {
+ if (ttlBuilder_ != null) {
+ return ttlBuilder_.getMessageOrBuilder();
+ } else {
+ return ttl_ == null ?
+ org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.getDefaultInstance() : ttl_;
+ }
+ }
+ /**
+ * <code>.org.apache.spark.sql.execution.streaming.state.TTLConfig ttl = 3;</code>
+ */
+ private com.google.protobuf.SingleFieldBuilderV3<
+ org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig, org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfigOrBuilder>
+ getTtlFieldBuilder() {
+ if (ttlBuilder_ == null) {
+ ttlBuilder_ = new com.google.protobuf.SingleFieldBuilderV3<
+ org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig, org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfigOrBuilder>(
+ getTtl(),
+ getParentForChildren(),
+ isClean());
+ ttl_ = null;
+ }
+ return ttlBuilder_;
+ }
+ @java.lang.Override
+ public final Builder setUnknownFields(
+ final com.google.protobuf.UnknownFieldSet unknownFields) {
+ return super.setUnknownFields(unknownFields);
+ }
+
+ @java.lang.Override
+ public final Builder mergeUnknownFields(
+ final com.google.protobuf.UnknownFieldSet unknownFields) {
+ return super.mergeUnknownFields(unknownFields);
+ }
+
+
// @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.StateCallCommand)
}
@@ -5551,7 +6034,7 @@ public interface ValueStateCallOrBuilder extends
*/
org.apache.spark.sql.execution.streaming.state.StateMessage.ClearOrBuilder getClearOrBuilder();
- org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCall.MethodCase getMethodCase();
+ public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCall.MethodCase getMethodCase();
}
/**
* Protobuf type {@code org.apache.spark.sql.execution.streaming.state.ValueStateCall}
@@ -5569,6 +6052,18 @@ private ValueStateCall() {
stateName_ = "";
}
+ @java.lang.Override
+ @SuppressWarnings({"unused"})
+ protected java.lang.Object newInstance(
+ UnusedPrivateParameter unused) {
+ return new ValueStateCall();
+ }
+
+ @java.lang.Override
+ public final com.google.protobuf.UnknownFieldSet
+ getUnknownFields() {
+ return this.unknownFields;
+ }
public static final com.google.protobuf.Descriptors.Descriptor
getDescriptor() {
return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateCall_descriptor;
@@ -5583,7 +6078,6 @@ private ValueStateCall() {
}
private int methodCase_ = 0;
- @SuppressWarnings("serial")
private java.lang.Object method_;
public enum MethodCase
implements com.google.protobuf.Internal.EnumLite,
@@ -5629,8 +6123,7 @@ public int getNumber() {
}
public static final int STATENAME_FIELD_NUMBER = 1;
- @SuppressWarnings("serial")
- private volatile java.lang.Object stateName_ = "";
+ private volatile java.lang.Object stateName_;
/**
* <code>string stateName = 1;</code>
* @return The stateName.
@@ -5968,13 +6461,11 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueS
return com.google.protobuf.GeneratedMessageV3
.parseWithIOException(PARSER, input, extensionRegistry);
}
-
public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCall parseDelimitedFrom(java.io.InputStream input)
throws java.io.IOException {
return com.google.protobuf.GeneratedMessageV3
.parseDelimitedWithIOException(PARSER, input);
}
-
public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCall parseDelimitedFrom(
java.io.InputStream input,
com.google.protobuf.ExtensionRegistryLite extensionRegistry)
@@ -6049,8 +6540,8 @@ private Builder(
@java.lang.Override
public Builder clear() {
super.clear();
- bitField0_ = 0;
stateName_ = "";
+
if (existsBuilder_ != null) {
existsBuilder_.clear();
}
@@ -6091,40 +6582,72 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCal
@java.lang.Override
public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCall buildPartial() {
org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCall result = new org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCall(this);
- if (bitField0_ != 0) { buildPartial0(result); }
- buildPartialOneofs(result);
- onBuilt();
- return result;
- }
-
- private void buildPartial0(org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCall result) {
- int from_bitField0_ = bitField0_;
- if (((from_bitField0_ & 0x00000001) != 0)) {
- result.stateName_ = stateName_;
- }
- }
-
- private void buildPartialOneofs(org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCall result) {
- result.methodCase_ = methodCase_;
- result.method_ = this.method_;
- if (methodCase_ == 2 &&
- existsBuilder_ != null) {
- result.method_ = existsBuilder_.build();
+ result.stateName_ = stateName_;
+ if (methodCase_ == 2) {
+ if (existsBuilder_ == null) {
+ result.method_ = method_;
+ } else {
+ result.method_ = existsBuilder_.build();
+ }
}
- if (methodCase_ == 3 &&
- getBuilder_ != null) {
- result.method_ = getBuilder_.build();
+ if (methodCase_ == 3) {
+ if (getBuilder_ == null) {
+ result.method_ = method_;
+ } else {
+ result.method_ = getBuilder_.build();
+ }
}
- if (methodCase_ == 4 &&
- valueStateUpdateBuilder_ != null) {
- result.method_ = valueStateUpdateBuilder_.build();
+ if (methodCase_ == 4) {
+ if (valueStateUpdateBuilder_ == null) {
+ result.method_ = method_;
+ } else {
+ result.method_ = valueStateUpdateBuilder_.build();
+ }
}
- if (methodCase_ == 5 &&
- clearBuilder_ != null) {
- result.method_ = clearBuilder_.build();
+ if (methodCase_ == 5) {
+ if (clearBuilder_ == null) {
+ result.method_ = method_;
+ } else {
+ result.method_ = clearBuilder_.build();
+ }
}
+ result.methodCase_ = methodCase_;
+ onBuilt();
+ return result;
}
+ @java.lang.Override
+ public Builder clone() {
+ return super.clone();
+ }
+ @java.lang.Override
+ public Builder setField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ java.lang.Object value) {
+ return super.setField(field, value);
+ }
+ @java.lang.Override
+ public Builder clearField(
+ com.google.protobuf.Descriptors.FieldDescriptor field) {
+ return super.clearField(field);
+ }
+ @java.lang.Override
+ public Builder clearOneof(
+ com.google.protobuf.Descriptors.OneofDescriptor oneof) {
+ return super.clearOneof(oneof);
+ }
+ @java.lang.Override
+ public Builder setRepeatedField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ int index, java.lang.Object value) {
+ return super.setRepeatedField(field, index, value);
+ }
+ @java.lang.Override
+ public Builder addRepeatedField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ java.lang.Object value) {
+ return super.addRepeatedField(field, value);
+ }
@java.lang.Override
public Builder mergeFrom(com.google.protobuf.Message other) {
if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCall) {
@@ -6139,7 +6662,6 @@ public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMes
if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCall.getDefaultInstance()) return this;
if (!other.getStateName().isEmpty()) {
stateName_ = other.stateName_;
- bitField0_ |= 0x00000001;
onChanged();
}
switch (other.getMethodCase()) {
@@ -6191,7 +6713,7 @@ public Builder mergeFrom(
break;
case 10: {
stateName_ = input.readStringRequireUtf8();
- bitField0_ |= 0x00000001;
+
break;
} // case 10
case 18: {
@@ -6252,7 +6774,6 @@ public Builder clearMethod() {
return this;
}
- private int bitField0_;
private java.lang.Object stateName_ = "";
/**
@@ -6295,9 +6816,11 @@ public java.lang.String getStateName() {
*/
public Builder setStateName(
java.lang.String value) {
- if (value == null) { throw new NullPointerException(); }
+ if (value == null) {
+ throw new NullPointerException();
+ }
+
stateName_ = value;
- bitField0_ |= 0x00000001;
onChanged();
return this;
}
@@ -6306,8 +6829,8 @@ public Builder setStateName(
* @return This builder for chaining.
*/
public Builder clearStateName() {
+
stateName_ = getDefaultInstance().getStateName();
- bitField0_ = (bitField0_ & ~0x00000001);
onChanged();
return this;
}
@@ -6318,10 +6841,12 @@ public Builder clearStateName() {
*/
public Builder setStateNameBytes(
com.google.protobuf.ByteString value) {
- if (value == null) { throw new NullPointerException(); }
- checkByteStringIsUtf8(value);
+ if (value == null) {
+ throw new NullPointerException();
+ }
+ checkByteStringIsUtf8(value);
+
stateName_ = value;
- bitField0_ |= 0x00000001;
onChanged();
return this;
}
@@ -6464,7 +6989,7 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.ExistsOrBuild
method_ = null;
}
methodCase_ = 2;
- onChanged();
+ onChanged();;
return existsBuilder_;
}
@@ -6606,7 +7131,7 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.GetOrBuilder
method_ = null;
}
methodCase_ = 3;
- onChanged();
+ onChanged();;
return getBuilder_;
}
@@ -6748,7 +7273,7 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpd
method_ = null;
}
methodCase_ = 4;
- onChanged();
+ onChanged();;
return valueStateUpdateBuilder_;
}
@@ -6890,9 +7415,21 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.ClearOrBuilde
method_ = null;
}
methodCase_ = 5;
- onChanged();
+ onChanged();;
return clearBuilder_;
}
+ @java.lang.Override
+ public final Builder setUnknownFields(
+ final com.google.protobuf.UnknownFieldSet unknownFields) {
+ return super.setUnknownFields(unknownFields);
+ }
+
+ @java.lang.Override
+ public final Builder mergeUnknownFields(
+ final com.google.protobuf.UnknownFieldSet unknownFields) {
+ return super.mergeUnknownFields(unknownFields);
+ }
+
// @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.ValueStateCall)
}
@@ -6971,6 +7508,18 @@ private SetImplicitKey() {
key_ = com.google.protobuf.ByteString.EMPTY;
}
+ @java.lang.Override
+ @SuppressWarnings({"unused"})
+ protected java.lang.Object newInstance(
+ UnusedPrivateParameter unused) {
+ return new SetImplicitKey();
+ }
+
+ @java.lang.Override
+ public final com.google.protobuf.UnknownFieldSet
+ getUnknownFields() {
+ return this.unknownFields;
+ }
public static final com.google.protobuf.Descriptors.Descriptor
getDescriptor() {
return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_descriptor;
@@ -6985,7 +7534,7 @@ private SetImplicitKey() {
}
public static final int KEY_FIELD_NUMBER = 1;
- private com.google.protobuf.ByteString key_ = com.google.protobuf.ByteString.EMPTY;
+ private com.google.protobuf.ByteString key_;
/**
* <code>bytes key = 1;</code>
* @return The key.
@@ -7104,13 +7653,11 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImp
return com.google.protobuf.GeneratedMessageV3
.parseWithIOException(PARSER, input, extensionRegistry);
}
-
public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseDelimitedFrom(java.io.InputStream input)
throws java.io.IOException {
return com.google.protobuf.GeneratedMessageV3
.parseDelimitedWithIOException(PARSER, input);
}
-
public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseDelimitedFrom(
java.io.InputStream input,
com.google.protobuf.ExtensionRegistryLite extensionRegistry)
@@ -7185,8 +7732,8 @@ private Builder(
@java.lang.Override
public Builder clear() {
super.clear();
- bitField0_ = 0;
key_ = com.google.protobuf.ByteString.EMPTY;
+
return this;
}
@@ -7213,18 +7760,43 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKe
@java.lang.Override
public org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey buildPartial() {
org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey result = new org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey(this);
- if (bitField0_ != 0) { buildPartial0(result); }
+ result.key_ = key_;
onBuilt();
return result;
}
- private void buildPartial0(org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey result) {
- int from_bitField0_ = bitField0_;
- if (((from_bitField0_ & 0x00000001) != 0)) {
- result.key_ = key_;
- }
+ @java.lang.Override
+ public Builder clone() {
+ return super.clone();
+ }
+ @java.lang.Override
+ public Builder setField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ java.lang.Object value) {
+ return super.setField(field, value);
+ }
+ @java.lang.Override
+ public Builder clearField(
+ com.google.protobuf.Descriptors.FieldDescriptor field) {
+ return super.clearField(field);
+ }
+ @java.lang.Override
+ public Builder clearOneof(
+ com.google.protobuf.Descriptors.OneofDescriptor oneof) {
+ return super.clearOneof(oneof);
+ }
+ @java.lang.Override
+ public Builder setRepeatedField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ int index, java.lang.Object value) {
+ return super.setRepeatedField(field, index, value);
+ }
+ @java.lang.Override
+ public Builder addRepeatedField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ java.lang.Object value) {
+ return super.addRepeatedField(field, value);
}
-
@java.lang.Override
public Builder mergeFrom(com.google.protobuf.Message other) {
if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey) {
@@ -7268,7 +7840,7 @@ public Builder mergeFrom(
break;
case 10: {
key_ = input.readBytes();
- bitField0_ |= 0x00000001;
+
break;
} // case 10
default: {
@@ -7286,7 +7858,6 @@ public Builder mergeFrom(
} // finally
return this;
}
- private int bitField0_;
private com.google.protobuf.ByteString key_ = com.google.protobuf.ByteString.EMPTY;
/**
@@ -7303,9 +7874,11 @@ public com.google.protobuf.ByteString getKey() {
* @return This builder for chaining.
*/
public Builder setKey(com.google.protobuf.ByteString value) {
- if (value == null) { throw new NullPointerException(); }
+ if (value == null) {
+ throw new NullPointerException();
+ }
+
key_ = value;
- bitField0_ |= 0x00000001;
onChanged();
return this;
}
@@ -7314,11 +7887,23 @@ public Builder setKey(com.google.protobuf.ByteString value) {
* @return This builder for chaining.
*/
public Builder clearKey() {
- bitField0_ = (bitField0_ & ~0x00000001);
+
key_ = getDefaultInstance().getKey();
onChanged();
return this;
}
+ @java.lang.Override
+ public final Builder setUnknownFields(
+ final com.google.protobuf.UnknownFieldSet unknownFields) {
+ return super.setUnknownFields(unknownFields);
+ }
+
+ @java.lang.Override
+ public final Builder mergeUnknownFields(
+ final com.google.protobuf.UnknownFieldSet unknownFields) {
+ return super.mergeUnknownFields(unknownFields);
+ }
+
// @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.SetImplicitKey)
}
@@ -7390,6 +7975,18 @@ private RemoveImplicitKey(com.google.protobuf.GeneratedMessageV3.Builder<?> buil
private RemoveImplicitKey() {
}
+ @java.lang.Override
+ @SuppressWarnings({"unused"})
+ protected java.lang.Object newInstance(
+ UnusedPrivateParameter unused) {
+ return new RemoveImplicitKey();
+ }
+
+ @java.lang.Override
+ public final com.google.protobuf.UnknownFieldSet
+ getUnknownFields() {
+ return this.unknownFields;
+ }
public static final com.google.protobuf.Descriptors.Descriptor
getDescriptor() {
return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_descriptor;
@@ -7501,13 +8098,11 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.Remove
return com.google.protobuf.GeneratedMessageV3
.parseWithIOException(PARSER, input, extensionRegistry);
}
-
public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseDelimitedFrom(java.io.InputStream input)
throws java.io.IOException {
return com.google.protobuf.GeneratedMessageV3
.parseDelimitedWithIOException(PARSER, input);
}
-
public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseDelimitedFrom(
java.io.InputStream input,
com.google.protobuf.ExtensionRegistryLite extensionRegistry)
@@ -7612,6 +8207,38 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplici
return result;
}
+ @java.lang.Override
+ public Builder clone() {
+ return super.clone();
+ }
+ @java.lang.Override
+ public Builder setField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ java.lang.Object value) {
+ return super.setField(field, value);
+ }
+ @java.lang.Override
+ public Builder clearField(
+ com.google.protobuf.Descriptors.FieldDescriptor field) {
+ return super.clearField(field);
+ }
+ @java.lang.Override
+ public Builder clearOneof(
+ com.google.protobuf.Descriptors.OneofDescriptor oneof) {
+ return super.clearOneof(oneof);
+ }
+ @java.lang.Override
+ public Builder setRepeatedField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ int index, java.lang.Object value) {
+ return super.setRepeatedField(field, index, value);
+ }
+ @java.lang.Override
+ public Builder addRepeatedField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ java.lang.Object value) {
+ return super.addRepeatedField(field, value);
+ }
@java.lang.Override
public Builder mergeFrom(com.google.protobuf.Message other) {
if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey) {
@@ -7665,6 +8292,18 @@ public Builder mergeFrom(
} // finally
return this;
}
+ @java.lang.Override
+ public final Builder setUnknownFields(
+ final com.google.protobuf.UnknownFieldSet unknownFields) {
+ return super.setUnknownFields(unknownFields);
+ }
+
+ @java.lang.Override
+ public final Builder mergeUnknownFields(
+ final com.google.protobuf.UnknownFieldSet unknownFields) {
+ return super.mergeUnknownFields(unknownFields);
+ }
+
// @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey)
}
@@ -7736,6 +8375,18 @@ private Exists(com.google.protobuf.GeneratedMessageV3.Builder<?> builder) {
private Exists() {
}
+ @java.lang.Override
+ @SuppressWarnings({"unused"})
+ protected java.lang.Object newInstance(
+ UnusedPrivateParameter unused) {
+ return new Exists();
+ }
+
+ @java.lang.Override
+ public final com.google.protobuf.UnknownFieldSet
+ getUnknownFields() {
+ return this.unknownFields;
+ }
public static final com.google.protobuf.Descriptors.Descriptor
getDescriptor() {
return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Exists_descriptor;
@@ -7847,13 +8498,11 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists
return com.google.protobuf.GeneratedMessageV3
.parseWithIOException(PARSER, input, extensionRegistry);
}
-
public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseDelimitedFrom(java.io.InputStream input)
throws java.io.IOException {
return com.google.protobuf.GeneratedMessageV3
.parseDelimitedWithIOException(PARSER, input);
}
-
public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseDelimitedFrom(
java.io.InputStream input,
com.google.protobuf.ExtensionRegistryLite extensionRegistry)
@@ -7958,6 +8607,38 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists buildP
return result;
}
+ @java.lang.Override
+ public Builder clone() {
+ return super.clone();
+ }
+ @java.lang.Override
+ public Builder setField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ java.lang.Object value) {
+ return super.setField(field, value);
+ }
+ @java.lang.Override
+ public Builder clearField(
+ com.google.protobuf.Descriptors.FieldDescriptor field) {
+ return super.clearField(field);
+ }
+ @java.lang.Override
+ public Builder clearOneof(
+ com.google.protobuf.Descriptors.OneofDescriptor oneof) {
+ return super.clearOneof(oneof);
+ }
+ @java.lang.Override
+ public Builder setRepeatedField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ int index, java.lang.Object value) {
+ return super.setRepeatedField(field, index, value);
+ }
+ @java.lang.Override
+ public Builder addRepeatedField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ java.lang.Object value) {
+ return super.addRepeatedField(field, value);
+ }
@java.lang.Override
public Builder mergeFrom(com.google.protobuf.Message other) {
if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) {
@@ -8011,6 +8692,18 @@ public Builder mergeFrom(
} // finally
return this;
}
+ @java.lang.Override
+ public final Builder setUnknownFields(
+ final com.google.protobuf.UnknownFieldSet unknownFields) {
+ return super.setUnknownFields(unknownFields);
+ }
+
+ @java.lang.Override
+ public final Builder mergeUnknownFields(
+ final com.google.protobuf.UnknownFieldSet unknownFields) {
+ return super.mergeUnknownFields(unknownFields);
+ }
+
// @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.Exists)
}
@@ -8082,6 +8775,18 @@ private Get(com.google.protobuf.GeneratedMessageV3.Builder<?> builder) {
private Get() {
}
+ @java.lang.Override
+ @SuppressWarnings({"unused"})
+ protected java.lang.Object newInstance(
+ UnusedPrivateParameter unused) {
+ return new Get();
+ }
+
+ @java.lang.Override
+ public final com.google.protobuf.UnknownFieldSet
+ getUnknownFields() {
+ return this.unknownFields;
+ }
public static final com.google.protobuf.Descriptors.Descriptor
getDescriptor() {
return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Get_descriptor;
@@ -8193,13 +8898,11 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get pa
return com.google.protobuf.GeneratedMessageV3
.parseWithIOException(PARSER, input, extensionRegistry);
}
-
public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseDelimitedFrom(java.io.InputStream input)
throws java.io.IOException {
return com.google.protobuf.GeneratedMessageV3
.parseDelimitedWithIOException(PARSER, input);
}
-
public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseDelimitedFrom(
java.io.InputStream input,
com.google.protobuf.ExtensionRegistryLite extensionRegistry)
@@ -8304,6 +9007,38 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.Get buildPart
return result;
}
+ @java.lang.Override
+ public Builder clone() {
+ return super.clone();
+ }
+ @java.lang.Override
+ public Builder setField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ java.lang.Object value) {
+ return super.setField(field, value);
+ }
+ @java.lang.Override
+ public Builder clearField(
+ com.google.protobuf.Descriptors.FieldDescriptor field) {
+ return super.clearField(field);
+ }
+ @java.lang.Override
+ public Builder clearOneof(
+ com.google.protobuf.Descriptors.OneofDescriptor oneof) {
+ return super.clearOneof(oneof);
+ }
+ @java.lang.Override
+ public Builder setRepeatedField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ int index, java.lang.Object value) {
+ return super.setRepeatedField(field, index, value);
+ }
+ @java.lang.Override
+ public Builder addRepeatedField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ java.lang.Object value) {
+ return super.addRepeatedField(field, value);
+ }
@java.lang.Override
public Builder mergeFrom(com.google.protobuf.Message other) {
if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Get) {
@@ -8357,6 +9092,18 @@ public Builder mergeFrom(
} // finally
return this;
}
+ @java.lang.Override
+ public final Builder setUnknownFields(
+ final com.google.protobuf.UnknownFieldSet unknownFields) {
+ return super.setUnknownFields(unknownFields);
+ }
+
+ @java.lang.Override
+ public final Builder mergeUnknownFields(
+ final com.google.protobuf.UnknownFieldSet unknownFields) {
+ return super.mergeUnknownFields(unknownFields);
+ }
+
// @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.Get)
}
@@ -8435,6 +9182,18 @@ private ValueStateUpdate() {
value_ = com.google.protobuf.ByteString.EMPTY;
}
+ @java.lang.Override
+ @SuppressWarnings({"unused"})
+ protected java.lang.Object newInstance(
+ UnusedPrivateParameter unused) {
+ return new ValueStateUpdate();
+ }
+
+ @java.lang.Override
+ public final com.google.protobuf.UnknownFieldSet
+ getUnknownFields() {
+ return this.unknownFields;
+ }
public static final com.google.protobuf.Descriptors.Descriptor
getDescriptor() {
return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_descriptor;
@@ -8449,7 +9208,7 @@ private ValueStateUpdate() {
}
public static final int VALUE_FIELD_NUMBER = 1;
- private com.google.protobuf.ByteString value_ = com.google.protobuf.ByteString.EMPTY;
+ private com.google.protobuf.ByteString value_;
/**
* bytes value = 1;
* @return The value.
@@ -8568,13 +9327,11 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueS
return com.google.protobuf.GeneratedMessageV3
.parseWithIOException(PARSER, input, extensionRegistry);
}
-
public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseDelimitedFrom(java.io.InputStream input)
throws java.io.IOException {
return com.google.protobuf.GeneratedMessageV3
.parseDelimitedWithIOException(PARSER, input);
}
-
public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseDelimitedFrom(
java.io.InputStream input,
com.google.protobuf.ExtensionRegistryLite extensionRegistry)
@@ -8649,8 +9406,8 @@ private Builder(
@java.lang.Override
public Builder clear() {
super.clear();
- bitField0_ = 0;
value_ = com.google.protobuf.ByteString.EMPTY;
+
return this;
}
@@ -8677,20 +9434,45 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpd
@java.lang.Override
public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate buildPartial() {
org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate result = new org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate(this);
- if (bitField0_ != 0) { buildPartial0(result); }
+ result.value_ = value_;
onBuilt();
return result;
}
- private void buildPartial0(org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate result) {
- int from_bitField0_ = bitField0_;
- if (((from_bitField0_ & 0x00000001) != 0)) {
- result.value_ = value_;
- }
+ @java.lang.Override
+ public Builder clone() {
+ return super.clone();
}
-
@java.lang.Override
- public Builder mergeFrom(com.google.protobuf.Message other) {
+ public Builder setField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ java.lang.Object value) {
+ return super.setField(field, value);
+ }
+ @java.lang.Override
+ public Builder clearField(
+ com.google.protobuf.Descriptors.FieldDescriptor field) {
+ return super.clearField(field);
+ }
+ @java.lang.Override
+ public Builder clearOneof(
+ com.google.protobuf.Descriptors.OneofDescriptor oneof) {
+ return super.clearOneof(oneof);
+ }
+ @java.lang.Override
+ public Builder setRepeatedField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ int index, java.lang.Object value) {
+ return super.setRepeatedField(field, index, value);
+ }
+ @java.lang.Override
+ public Builder addRepeatedField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ java.lang.Object value) {
+ return super.addRepeatedField(field, value);
+ }
+ @java.lang.Override
+ public Builder mergeFrom(com.google.protobuf.Message other) {
if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate) {
return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate)other);
} else {
@@ -8732,7 +9514,7 @@ public Builder mergeFrom(
break;
case 10: {
value_ = input.readBytes();
- bitField0_ |= 0x00000001;
+
break;
} // case 10
default: {
@@ -8750,7 +9532,6 @@ public Builder mergeFrom(
} // finally
return this;
}
- private int bitField0_;
private com.google.protobuf.ByteString value_ = com.google.protobuf.ByteString.EMPTY;
/**
@@ -8767,9 +9548,11 @@ public com.google.protobuf.ByteString getValue() {
* @return This builder for chaining.
*/
public Builder setValue(com.google.protobuf.ByteString value) {
- if (value == null) { throw new NullPointerException(); }
+ if (value == null) {
+ throw new NullPointerException();
+ }
+
value_ = value;
- bitField0_ |= 0x00000001;
onChanged();
return this;
}
@@ -8778,11 +9561,23 @@ public Builder setValue(com.google.protobuf.ByteString value) {
* @return This builder for chaining.
*/
public Builder clearValue() {
- bitField0_ = (bitField0_ & ~0x00000001);
+
value_ = getDefaultInstance().getValue();
onChanged();
return this;
}
+ @java.lang.Override
+ public final Builder setUnknownFields(
+ final com.google.protobuf.UnknownFieldSet unknownFields) {
+ return super.setUnknownFields(unknownFields);
+ }
+
+ @java.lang.Override
+ public final Builder mergeUnknownFields(
+ final com.google.protobuf.UnknownFieldSet unknownFields) {
+ return super.mergeUnknownFields(unknownFields);
+ }
+
// @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.ValueStateUpdate)
}
@@ -8854,6 +9649,18 @@ private Clear(com.google.protobuf.GeneratedMessageV3.Builder<?> builder) {
private Clear() {
}
+ @java.lang.Override
+ @SuppressWarnings({"unused"})
+ protected java.lang.Object newInstance(
+ UnusedPrivateParameter unused) {
+ return new Clear();
+ }
+
+ @java.lang.Override
+ public final com.google.protobuf.UnknownFieldSet
+ getUnknownFields() {
+ return this.unknownFields;
+ }
public static final com.google.protobuf.Descriptors.Descriptor
getDescriptor() {
return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Clear_descriptor;
@@ -8965,13 +9772,11 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear
return com.google.protobuf.GeneratedMessageV3
.parseWithIOException(PARSER, input, extensionRegistry);
}
-
public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseDelimitedFrom(java.io.InputStream input)
throws java.io.IOException {
return com.google.protobuf.GeneratedMessageV3
.parseDelimitedWithIOException(PARSER, input);
}
-
public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseDelimitedFrom(
java.io.InputStream input,
com.google.protobuf.ExtensionRegistryLite extensionRegistry)
@@ -9076,6 +9881,38 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear buildPa
return result;
}
+ @java.lang.Override
+ public Builder clone() {
+ return super.clone();
+ }
+ @java.lang.Override
+ public Builder setField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ java.lang.Object value) {
+ return super.setField(field, value);
+ }
+ @java.lang.Override
+ public Builder clearField(
+ com.google.protobuf.Descriptors.FieldDescriptor field) {
+ return super.clearField(field);
+ }
+ @java.lang.Override
+ public Builder clearOneof(
+ com.google.protobuf.Descriptors.OneofDescriptor oneof) {
+ return super.clearOneof(oneof);
+ }
+ @java.lang.Override
+ public Builder setRepeatedField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ int index, java.lang.Object value) {
+ return super.setRepeatedField(field, index, value);
+ }
+ @java.lang.Override
+ public Builder addRepeatedField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ java.lang.Object value) {
+ return super.addRepeatedField(field, value);
+ }
@java.lang.Override
public Builder mergeFrom(com.google.protobuf.Message other) {
if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) {
@@ -9129,6 +9966,18 @@ public Builder mergeFrom(
} // finally
return this;
}
+ @java.lang.Override
+ public final Builder setUnknownFields(
+ final com.google.protobuf.UnknownFieldSet unknownFields) {
+ return super.setUnknownFields(unknownFields);
+ }
+
+ @java.lang.Override
+ public final Builder mergeUnknownFields(
+ final com.google.protobuf.UnknownFieldSet unknownFields) {
+ return super.mergeUnknownFields(unknownFields);
+ }
+
// @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.Clear)
}
@@ -9212,6 +10061,18 @@ private SetHandleState() {
state_ = 0;
}
+ @java.lang.Override
+ @SuppressWarnings({"unused"})
+ protected java.lang.Object newInstance(
+ UnusedPrivateParameter unused) {
+ return new SetHandleState();
+ }
+
+ @java.lang.Override
+ public final com.google.protobuf.UnknownFieldSet
+ getUnknownFields() {
+ return this.unknownFields;
+ }
public static final com.google.protobuf.Descriptors.Descriptor
getDescriptor() {
return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_SetHandleState_descriptor;
@@ -9226,7 +10087,7 @@ private SetHandleState() {
}
public static final int STATE_FIELD_NUMBER = 1;
- private int state_ = 0;
+ private int state_;
/**
* .org.apache.spark.sql.execution.streaming.state.HandleState state = 1;
* @return The enum numeric value on the wire for state.
@@ -9239,7 +10100,8 @@ private SetHandleState() {
* @return The state.
*/
@java.lang.Override public org.apache.spark.sql.execution.streaming.state.StateMessage.HandleState getState() {
- org.apache.spark.sql.execution.streaming.state.StateMessage.HandleState result = org.apache.spark.sql.execution.streaming.state.StateMessage.HandleState.forNumber(state_);
+ @SuppressWarnings("deprecation")
+ org.apache.spark.sql.execution.streaming.state.StateMessage.HandleState result = org.apache.spark.sql.execution.streaming.state.StateMessage.HandleState.valueOf(state_);
return result == null ? org.apache.spark.sql.execution.streaming.state.StateMessage.HandleState.UNRECOGNIZED : result;
}
@@ -9351,13 +10213,11 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetHan
return com.google.protobuf.GeneratedMessageV3
.parseWithIOException(PARSER, input, extensionRegistry);
}
-
public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetHandleState parseDelimitedFrom(java.io.InputStream input)
throws java.io.IOException {
return com.google.protobuf.GeneratedMessageV3
.parseDelimitedWithIOException(PARSER, input);
}
-
public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetHandleState parseDelimitedFrom(
java.io.InputStream input,
com.google.protobuf.ExtensionRegistryLite extensionRegistry)
@@ -9432,8 +10292,8 @@ private Builder(
@java.lang.Override
public Builder clear() {
super.clear();
- bitField0_ = 0;
state_ = 0;
+
return this;
}
@@ -9460,18 +10320,43 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.SetHandleStat
@java.lang.Override
public org.apache.spark.sql.execution.streaming.state.StateMessage.SetHandleState buildPartial() {
org.apache.spark.sql.execution.streaming.state.StateMessage.SetHandleState result = new org.apache.spark.sql.execution.streaming.state.StateMessage.SetHandleState(this);
- if (bitField0_ != 0) { buildPartial0(result); }
+ result.state_ = state_;
onBuilt();
return result;
}
- private void buildPartial0(org.apache.spark.sql.execution.streaming.state.StateMessage.SetHandleState result) {
- int from_bitField0_ = bitField0_;
- if (((from_bitField0_ & 0x00000001) != 0)) {
- result.state_ = state_;
- }
+ @java.lang.Override
+ public Builder clone() {
+ return super.clone();
+ }
+ @java.lang.Override
+ public Builder setField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ java.lang.Object value) {
+ return super.setField(field, value);
+ }
+ @java.lang.Override
+ public Builder clearField(
+ com.google.protobuf.Descriptors.FieldDescriptor field) {
+ return super.clearField(field);
+ }
+ @java.lang.Override
+ public Builder clearOneof(
+ com.google.protobuf.Descriptors.OneofDescriptor oneof) {
+ return super.clearOneof(oneof);
+ }
+ @java.lang.Override
+ public Builder setRepeatedField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ int index, java.lang.Object value) {
+ return super.setRepeatedField(field, index, value);
+ }
+ @java.lang.Override
+ public Builder addRepeatedField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ java.lang.Object value) {
+ return super.addRepeatedField(field, value);
}
-
@java.lang.Override
public Builder mergeFrom(com.google.protobuf.Message other) {
if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.SetHandleState) {
@@ -9515,7 +10400,7 @@ public Builder mergeFrom(
break;
case 8: {
state_ = input.readEnum();
- bitField0_ |= 0x00000001;
+
break;
} // case 8
default: {
@@ -9533,7 +10418,6 @@ public Builder mergeFrom(
} // finally
return this;
}
- private int bitField0_;
private int state_ = 0;
/**
@@ -9549,8 +10433,8 @@ public Builder mergeFrom(
* @return This builder for chaining.
*/
public Builder setStateValue(int value) {
+
state_ = value;
- bitField0_ |= 0x00000001;
onChanged();
return this;
}
@@ -9560,7 +10444,8 @@ public Builder setStateValue(int value) {
*/
@java.lang.Override
public org.apache.spark.sql.execution.streaming.state.StateMessage.HandleState getState() {
- org.apache.spark.sql.execution.streaming.state.StateMessage.HandleState result = org.apache.spark.sql.execution.streaming.state.StateMessage.HandleState.forNumber(state_);
+ @SuppressWarnings("deprecation")
+ org.apache.spark.sql.execution.streaming.state.StateMessage.HandleState result = org.apache.spark.sql.execution.streaming.state.StateMessage.HandleState.valueOf(state_);
return result == null ? org.apache.spark.sql.execution.streaming.state.StateMessage.HandleState.UNRECOGNIZED : result;
}
/**
@@ -9572,7 +10457,7 @@ public Builder setState(org.apache.spark.sql.execution.streaming.state.StateMess
if (value == null) {
throw new NullPointerException();
}
- bitField0_ |= 0x00000001;
+
state_ = value.getNumber();
onChanged();
return this;
@@ -9582,11 +10467,23 @@ public Builder setState(org.apache.spark.sql.execution.streaming.state.StateMess
* @return This builder for chaining.
*/
public Builder clearState() {
- bitField0_ = (bitField0_ & ~0x00000001);
+
state_ = 0;
onChanged();
return this;
}
+ @java.lang.Override
+ public final Builder setUnknownFields(
+ final com.google.protobuf.UnknownFieldSet unknownFields) {
+ return super.setUnknownFields(unknownFields);
+ }
+
+ @java.lang.Override
+ public final Builder mergeUnknownFields(
+ final com.google.protobuf.UnknownFieldSet unknownFields) {
+ return super.mergeUnknownFields(unknownFields);
+ }
+
// @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.SetHandleState)
}
@@ -9639,6 +10536,476 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.SetHandleStat
}
+ public interface TTLConfigOrBuilder extends
+ // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.TTLConfig)
+ com.google.protobuf.MessageOrBuilder {
+
+ /**
+ * int32 durationMs = 1;
+ * @return The durationMs.
+ */
+ int getDurationMs();
+ }
+ /**
+ * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.TTLConfig}
+ */
+ public static final class TTLConfig extends
+ com.google.protobuf.GeneratedMessageV3 implements
+ // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.TTLConfig)
+ TTLConfigOrBuilder {
+ private static final long serialVersionUID = 0L;
+ // Use TTLConfig.newBuilder() to construct.
+ private TTLConfig(com.google.protobuf.GeneratedMessageV3.Builder<?> builder) {
+ super(builder);
+ }
+ private TTLConfig() {
+ }
+
+ @java.lang.Override
+ @SuppressWarnings({"unused"})
+ protected java.lang.Object newInstance(
+ UnusedPrivateParameter unused) {
+ return new TTLConfig();
+ }
+
+ @java.lang.Override
+ public final com.google.protobuf.UnknownFieldSet
+ getUnknownFields() {
+ return this.unknownFields;
+ }
+ public static final com.google.protobuf.Descriptors.Descriptor
+ getDescriptor() {
+ return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_TTLConfig_descriptor;
+ }
+
+ @java.lang.Override
+ protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable
+ internalGetFieldAccessorTable() {
+ return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_TTLConfig_fieldAccessorTable
+ .ensureFieldAccessorsInitialized(
+ org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.class, org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.Builder.class);
+ }
+
+ public static final int DURATIONMS_FIELD_NUMBER = 1;
+ private int durationMs_;
+ /**
+ * int32 durationMs = 1;
+ * @return The durationMs.
+ */
+ @java.lang.Override
+ public int getDurationMs() {
+ return durationMs_;
+ }
+
+ private byte memoizedIsInitialized = -1;
+ @java.lang.Override
+ public final boolean isInitialized() {
+ byte isInitialized = memoizedIsInitialized;
+ if (isInitialized == 1) return true;
+ if (isInitialized == 0) return false;
+
+ memoizedIsInitialized = 1;
+ return true;
+ }
+
+ @java.lang.Override
+ public void writeTo(com.google.protobuf.CodedOutputStream output)
+ throws java.io.IOException {
+ if (durationMs_ != 0) {
+ output.writeInt32(1, durationMs_);
+ }
+ getUnknownFields().writeTo(output);
+ }
+
+ @java.lang.Override
+ public int getSerializedSize() {
+ int size = memoizedSize;
+ if (size != -1) return size;
+
+ size = 0;
+ if (durationMs_ != 0) {
+ size += com.google.protobuf.CodedOutputStream
+ .computeInt32Size(1, durationMs_);
+ }
+ size += getUnknownFields().getSerializedSize();
+ memoizedSize = size;
+ return size;
+ }
+
+ @java.lang.Override
+ public boolean equals(final java.lang.Object obj) {
+ if (obj == this) {
+ return true;
+ }
+ if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig)) {
+ return super.equals(obj);
+ }
+ org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig other = (org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig) obj;
+
+ if (getDurationMs()
+ != other.getDurationMs()) return false;
+ if (!getUnknownFields().equals(other.getUnknownFields())) return false;
+ return true;
+ }
+
+ @java.lang.Override
+ public int hashCode() {
+ if (memoizedHashCode != 0) {
+ return memoizedHashCode;
+ }
+ int hash = 41;
+ hash = (19 * hash) + getDescriptor().hashCode();
+ hash = (37 * hash) + DURATIONMS_FIELD_NUMBER;
+ hash = (53 * hash) + getDurationMs();
+ hash = (29 * hash) + getUnknownFields().hashCode();
+ memoizedHashCode = hash;
+ return hash;
+ }
+
+ public static org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig parseFrom(
+ java.nio.ByteBuffer data)
+ throws com.google.protobuf.InvalidProtocolBufferException {
+ return PARSER.parseFrom(data);
+ }
+ public static org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig parseFrom(
+ java.nio.ByteBuffer data,
+ com.google.protobuf.ExtensionRegistryLite extensionRegistry)
+ throws com.google.protobuf.InvalidProtocolBufferException {
+ return PARSER.parseFrom(data, extensionRegistry);
+ }
+ public static org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig parseFrom(
+ com.google.protobuf.ByteString data)
+ throws com.google.protobuf.InvalidProtocolBufferException {
+ return PARSER.parseFrom(data);
+ }
+ public static org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig parseFrom(
+ com.google.protobuf.ByteString data,
+ com.google.protobuf.ExtensionRegistryLite extensionRegistry)
+ throws com.google.protobuf.InvalidProtocolBufferException {
+ return PARSER.parseFrom(data, extensionRegistry);
+ }
+ public static org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig parseFrom(byte[] data)
+ throws com.google.protobuf.InvalidProtocolBufferException {
+ return PARSER.parseFrom(data);
+ }
+ public static org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig parseFrom(
+ byte[] data,
+ com.google.protobuf.ExtensionRegistryLite extensionRegistry)
+ throws com.google.protobuf.InvalidProtocolBufferException {
+ return PARSER.parseFrom(data, extensionRegistry);
+ }
+ public static org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig parseFrom(java.io.InputStream input)
+ throws java.io.IOException {
+ return com.google.protobuf.GeneratedMessageV3
+ .parseWithIOException(PARSER, input);
+ }
+ public static org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig parseFrom(
+ java.io.InputStream input,
+ com.google.protobuf.ExtensionRegistryLite extensionRegistry)
+ throws java.io.IOException {
+ return com.google.protobuf.GeneratedMessageV3
+ .parseWithIOException(PARSER, input, extensionRegistry);
+ }
+ public static org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig parseDelimitedFrom(java.io.InputStream input)
+ throws java.io.IOException {
+ return com.google.protobuf.GeneratedMessageV3
+ .parseDelimitedWithIOException(PARSER, input);
+ }
+ public static org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig parseDelimitedFrom(
+ java.io.InputStream input,
+ com.google.protobuf.ExtensionRegistryLite extensionRegistry)
+ throws java.io.IOException {
+ return com.google.protobuf.GeneratedMessageV3
+ .parseDelimitedWithIOException(PARSER, input, extensionRegistry);
+ }
+ public static org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig parseFrom(
+ com.google.protobuf.CodedInputStream input)
+ throws java.io.IOException {
+ return com.google.protobuf.GeneratedMessageV3
+ .parseWithIOException(PARSER, input);
+ }
+ public static org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig parseFrom(
+ com.google.protobuf.CodedInputStream input,
+ com.google.protobuf.ExtensionRegistryLite extensionRegistry)
+ throws java.io.IOException {
+ return com.google.protobuf.GeneratedMessageV3
+ .parseWithIOException(PARSER, input, extensionRegistry);
+ }
+
+ @java.lang.Override
+ public Builder newBuilderForType() { return newBuilder(); }
+ public static Builder newBuilder() {
+ return DEFAULT_INSTANCE.toBuilder();
+ }
+ public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig prototype) {
+ return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype);
+ }
+ @java.lang.Override
+ public Builder toBuilder() {
+ return this == DEFAULT_INSTANCE
+ ? new Builder() : new Builder().mergeFrom(this);
+ }
+
+ @java.lang.Override
+ protected Builder newBuilderForType(
+ com.google.protobuf.GeneratedMessageV3.BuilderParent parent) {
+ Builder builder = new Builder(parent);
+ return builder;
+ }
+ /**
+ * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.TTLConfig}
+ */
+ public static final class Builder extends
+ com.google.protobuf.GeneratedMessageV3.Builder<Builder> implements
+ // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.TTLConfig)
+ org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfigOrBuilder {
+ public static final com.google.protobuf.Descriptors.Descriptor
+ getDescriptor() {
+ return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_TTLConfig_descriptor;
+ }
+
+ @java.lang.Override
+ protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable
+ internalGetFieldAccessorTable() {
+ return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_TTLConfig_fieldAccessorTable
+ .ensureFieldAccessorsInitialized(
+ org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.class, org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.Builder.class);
+ }
+
+ // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.newBuilder()
+ private Builder() {
+
+ }
+
+ private Builder(
+ com.google.protobuf.GeneratedMessageV3.BuilderParent parent) {
+ super(parent);
+
+ }
+ @java.lang.Override
+ public Builder clear() {
+ super.clear();
+ durationMs_ = 0;
+
+ return this;
+ }
+
+ @java.lang.Override
+ public com.google.protobuf.Descriptors.Descriptor
+ getDescriptorForType() {
+ return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_TTLConfig_descriptor;
+ }
+
+ @java.lang.Override
+ public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig getDefaultInstanceForType() {
+ return org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.getDefaultInstance();
+ }
+
+ @java.lang.Override
+ public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig build() {
+ org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig result = buildPartial();
+ if (!result.isInitialized()) {
+ throw newUninitializedMessageException(result);
+ }
+ return result;
+ }
+
+ @java.lang.Override
+ public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig buildPartial() {
+ org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig result = new org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig(this);
+ result.durationMs_ = durationMs_;
+ onBuilt();
+ return result;
+ }
+
+ @java.lang.Override
+ public Builder clone() {
+ return super.clone();
+ }
+ @java.lang.Override
+ public Builder setField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ java.lang.Object value) {
+ return super.setField(field, value);
+ }
+ @java.lang.Override
+ public Builder clearField(
+ com.google.protobuf.Descriptors.FieldDescriptor field) {
+ return super.clearField(field);
+ }
+ @java.lang.Override
+ public Builder clearOneof(
+ com.google.protobuf.Descriptors.OneofDescriptor oneof) {
+ return super.clearOneof(oneof);
+ }
+ @java.lang.Override
+ public Builder setRepeatedField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ int index, java.lang.Object value) {
+ return super.setRepeatedField(field, index, value);
+ }
+ @java.lang.Override
+ public Builder addRepeatedField(
+ com.google.protobuf.Descriptors.FieldDescriptor field,
+ java.lang.Object value) {
+ return super.addRepeatedField(field, value);
+ }
+ @java.lang.Override
+ public Builder mergeFrom(com.google.protobuf.Message other) {
+ if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig) {
+ return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig)other);
+ } else {
+ super.mergeFrom(other);
+ return this;
+ }
+ }
+
+ public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig other) {
+ if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig.getDefaultInstance()) return this;
+ if (other.getDurationMs() != 0) {
+ setDurationMs(other.getDurationMs());
+ }
+ this.mergeUnknownFields(other.getUnknownFields());
+ onChanged();
+ return this;
+ }
+
+ @java.lang.Override
+ public final boolean isInitialized() {
+ return true;
+ }
+
+ @java.lang.Override
+ public Builder mergeFrom(
+ com.google.protobuf.CodedInputStream input,
+ com.google.protobuf.ExtensionRegistryLite extensionRegistry)
+ throws java.io.IOException {
+ if (extensionRegistry == null) {
+ throw new java.lang.NullPointerException();
+ }
+ try {
+ boolean done = false;
+ while (!done) {
+ int tag = input.readTag();
+ switch (tag) {
+ case 0:
+ done = true;
+ break;
+ case 8: {
+ durationMs_ = input.readInt32();
+
+ break;
+ } // case 8
+ default: {
+ if (!super.parseUnknownField(input, extensionRegistry, tag)) {
+ done = true; // was an endgroup tag
+ }
+ break;
+ } // default:
+ } // switch (tag)
+ } // while (!done)
+ } catch (com.google.protobuf.InvalidProtocolBufferException e) {
+ throw e.unwrapIOException();
+ } finally {
+ onChanged();
+ } // finally
+ return this;
+ }
+
+ private int durationMs_ ;
+ /**
+ * int32 durationMs = 1;
+ * @return The durationMs.
+ */
+ @java.lang.Override
+ public int getDurationMs() {
+ return durationMs_;
+ }
+ /**
+ * int32 durationMs = 1;
+ * @param value The durationMs to set.
+ * @return This builder for chaining.
+ */
+ public Builder setDurationMs(int value) {
+
+ durationMs_ = value;
+ onChanged();
+ return this;
+ }
+ /**
+ * int32 durationMs = 1;
+ * @return This builder for chaining.
+ */
+ public Builder clearDurationMs() {
+
+ durationMs_ = 0;
+ onChanged();
+ return this;
+ }
+ @java.lang.Override
+ public final Builder setUnknownFields(
+ final com.google.protobuf.UnknownFieldSet unknownFields) {
+ return super.setUnknownFields(unknownFields);
+ }
+
+ @java.lang.Override
+ public final Builder mergeUnknownFields(
+ final com.google.protobuf.UnknownFieldSet unknownFields) {
+ return super.mergeUnknownFields(unknownFields);
+ }
+
+
+ // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.TTLConfig)
+ }
+
+ // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.TTLConfig)
+ private static final org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig DEFAULT_INSTANCE;
+ static {
+ DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig();
+ }
+
+ public static org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig getDefaultInstance() {
+ return DEFAULT_INSTANCE;
+ }
+
+ private static final com.google.protobuf.Parser<TTLConfig>
+ PARSER = new com.google.protobuf.AbstractParser<TTLConfig>() {
+ @java.lang.Override
+ public TTLConfig parsePartialFrom(
+ com.google.protobuf.CodedInputStream input,
+ com.google.protobuf.ExtensionRegistryLite extensionRegistry)
+ throws com.google.protobuf.InvalidProtocolBufferException {
+ Builder builder = newBuilder();
+ try {
+ builder.mergeFrom(input, extensionRegistry);
+ } catch (com.google.protobuf.InvalidProtocolBufferException e) {
+ throw e.setUnfinishedMessage(builder.buildPartial());
+ } catch (com.google.protobuf.UninitializedMessageException e) {
+ throw e.asInvalidProtocolBufferException().setUnfinishedMessage(builder.buildPartial());
+ } catch (java.io.IOException e) {
+ throw new com.google.protobuf.InvalidProtocolBufferException(e)
+ .setUnfinishedMessage(builder.buildPartial());
+ }
+ return builder.buildPartial();
+ }
+ };
+
+ public static com.google.protobuf.Parser<TTLConfig> parser() {
+ return PARSER;
+ }
+
+ @java.lang.Override
+ public com.google.protobuf.Parser<TTLConfig> getParserForType() {
+ return PARSER;
+ }
+
+ @java.lang.Override
+ public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig getDefaultInstanceForType() {
+ return DEFAULT_INSTANCE;
+ }
+
+ }
+
private static final com.google.protobuf.Descriptors.Descriptor
internal_static_org_apache_spark_sql_execution_streaming_state_StateRequest_descriptor;
private static final
@@ -9709,6 +11076,11 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.SetHandleStat
private static final
com.google.protobuf.GeneratedMessageV3.FieldAccessorTable
internal_static_org_apache_spark_sql_execution_streaming_state_SetHandleState_fieldAccessorTable;
+ private static final com.google.protobuf.Descriptors.Descriptor
+ internal_static_org_apache_spark_sql_execution_streaming_state_TTLConfig_descriptor;
+ private static final
+ com.google.protobuf.GeneratedMessageV3.FieldAccessorTable
+ internal_static_org_apache_spark_sql_execution_streaming_state_TTLConfig_fieldAccessorTable;
public static com.google.protobuf.Descriptors.FileDescriptor
getDescriptor() {
@@ -9749,24 +11121,27 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.SetHandleStat
"ming.state.SetImplicitKeyH\000\022^\n\021removeImp" +
"licitKey\030\002 \001(\0132A.org.apache.spark.sql.ex" +
"ecution.streaming.state.RemoveImplicitKe" +
- "yH\000B\010\n\006method\"5\n\020StateCallCommand\022\021\n\tsta" +
- "teName\030\001 \001(\t\022\016\n\006schema\030\002 \001(\t\"\341\002\n\016ValueSt" +
- "ateCall\022\021\n\tstateName\030\001 \001(\t\022H\n\006exists\030\002 \001" +
- "(\01326.org.apache.spark.sql.execution.stre" +
- "aming.state.ExistsH\000\022B\n\003get\030\003 \001(\01323.org." +
- "apache.spark.sql.execution.streaming.sta" +
- "te.GetH\000\022\\\n\020valueStateUpdate\030\004 \001(\0132@.org" +
+ "yH\000B\010\n\006method\"}\n\020StateCallCommand\022\021\n\tsta" +
+ "teName\030\001 \001(\t\022\016\n\006schema\030\002 \001(\t\022F\n\003ttl\030\003 \001(" +
+ "\01329.org.apache.spark.sql.execution.strea" +
+ "ming.state.TTLConfig\"\341\002\n\016ValueStateCall\022" +
+ "\021\n\tstateName\030\001 \001(\t\022H\n\006exists\030\002 \001(\01326.org" +
".apache.spark.sql.execution.streaming.st" +
- "ate.ValueStateUpdateH\000\022F\n\005clear\030\005 \001(\01325." +
- "org.apache.spark.sql.execution.streaming" +
- ".state.ClearH\000B\010\n\006method\"\035\n\016SetImplicitK" +
- "ey\022\013\n\003key\030\001 \001(\014\"\023\n\021RemoveImplicitKey\"\010\n\006" +
- "Exists\"\005\n\003Get\"!\n\020ValueStateUpdate\022\r\n\005val" +
- "ue\030\001 \001(\014\"\007\n\005Clear\"\\\n\016SetHandleState\022J\n\005s" +
- "tate\030\001 \001(\0162;.org.apache.spark.sql.execut" +
- "ion.streaming.state.HandleState*K\n\013Handl" +
- "eState\022\013\n\007CREATED\020\000\022\017\n\013INITIALIZED\020\001\022\022\n\016" +
- "DATA_PROCESSED\020\002\022\n\n\006CLOSED\020\003b\006proto3"
+ "ate.ExistsH\000\022B\n\003get\030\003 \001(\01323.org.apache.s" +
+ "park.sql.execution.streaming.state.GetH\000" +
+ "\022\\\n\020valueStateUpdate\030\004 \001(\0132@.org.apache." +
+ "spark.sql.execution.streaming.state.Valu" +
+ "eStateUpdateH\000\022F\n\005clear\030\005 \001(\01325.org.apac" +
+ "he.spark.sql.execution.streaming.state.C" +
+ "learH\000B\010\n\006method\"\035\n\016SetImplicitKey\022\013\n\003ke" +
+ "y\030\001 \001(\014\"\023\n\021RemoveImplicitKey\"\010\n\006Exists\"\005" +
+ "\n\003Get\"!\n\020ValueStateUpdate\022\r\n\005value\030\001 \001(\014" +
+ "\"\007\n\005Clear\"\\\n\016SetHandleState\022J\n\005state\030\001 \001" +
+ "(\0162;.org.apache.spark.sql.execution.stre" +
+ "aming.state.HandleState\"\037\n\tTTLConfig\022\022\n\n" +
+ "durationMs\030\001 \001(\005*K\n\013HandleState\022\013\n\007CREAT" +
+ "ED\020\000\022\017\n\013INITIALIZED\020\001\022\022\n\016DATA_PROCESSED\020" +
+ "\002\022\n\n\006CLOSED\020\003b\006proto3"
};
descriptor = com.google.protobuf.Descriptors.FileDescriptor
.internalBuildGeneratedFileFrom(descriptorData,
@@ -9807,7 +11182,7 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.SetHandleStat
internal_static_org_apache_spark_sql_execution_streaming_state_StateCallCommand_fieldAccessorTable = new
com.google.protobuf.GeneratedMessageV3.FieldAccessorTable(
internal_static_org_apache_spark_sql_execution_streaming_state_StateCallCommand_descriptor,
- new java.lang.String[] { "StateName", "Schema", });
+ new java.lang.String[] { "StateName", "Schema", "Ttl", });
internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateCall_descriptor =
getDescriptor().getMessageTypes().get(6);
internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateCall_fieldAccessorTable = new
@@ -9856,6 +11231,12 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.SetHandleStat
com.google.protobuf.GeneratedMessageV3.FieldAccessorTable(
internal_static_org_apache_spark_sql_execution_streaming_state_SetHandleState_descriptor,
new java.lang.String[] { "State", });
+ internal_static_org_apache_spark_sql_execution_streaming_state_TTLConfig_descriptor =
+ getDescriptor().getMessageTypes().get(14);
+ internal_static_org_apache_spark_sql_execution_streaming_state_TTLConfig_fieldAccessorTable = new
+ com.google.protobuf.GeneratedMessageV3.FieldAccessorTable(
+ internal_static_org_apache_spark_sql_execution_streaming_state_TTLConfig_descriptor,
+ new java.lang.String[] { "DurationMs", });
}
// @@protoc_insertion_point(outer_class_scope)
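
The regenerated StateMessage bindings above add a TTLConfig message (a single int32 durationMs field) and register a new "ttl" field on StateCallCommand (accessor names "StateName", "Schema", "Ttl"). A minimal round-trip sketch of the new message follows, assuming the standard protoc-generated setTtl/getTtl accessors on StateCallCommand that the updated descriptor implies but that are not shown verbatim in this hunk:

    import org.apache.spark.sql.execution.streaming.state.StateMessage.StateCallCommand;
    import org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig;

    public class TtlConfigRoundTrip {
      public static void main(String[] args) throws Exception {
        // Build the new TTLConfig message with a one-second TTL.
        TTLConfig ttl = TTLConfig.newBuilder().setDurationMs(1000).build();

        // Attach it to a StateCallCommand via the new "ttl" field (field number 3).
        StateCallCommand cmd = StateCallCommand.newBuilder()
            .setStateName("valueState")
            .setSchema("value INT")
            .setTtl(ttl)            // assumed generated setter for the new field
            .build();

        // Serialize and parse back to confirm the field survives the wire format.
        StateCallCommand parsed = StateCallCommand.parseFrom(cmd.toByteArray());
        System.out.println(parsed.getTtl().getDurationMs()); // 1000
      }
    }
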
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
index 8b2fc7f5db31e..10594d6c5d340 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
@@ -72,7 +72,7 @@ public void reset() {
numNulls = 0;
}
- if (hugeVectorThreshold > 0 && capacity > hugeVectorThreshold) {
+ if (hugeVectorThreshold > -1 && capacity > hugeVectorThreshold) {
capacity = defaultCapacity;
releaseMemory();
reserveInternal(capacity);
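
The WritableColumnVector change above relaxes the shrink guard in reset() from hugeVectorThreshold > 0 to hugeVectorThreshold > -1, so a threshold of 0 now also makes an oversized vector fall back to defaultCapacity, while -1 still disables the behavior. A hypothetical standalone sketch of just that predicate (not Spark's class) illustrating the boundary:

    // Hypothetical helper mirroring only the guard condition in WritableColumnVector.reset().
    final class ResetShrinkPolicy {
      static boolean shouldShrink(long hugeVectorThreshold, long capacity) {
        // After the change: any non-negative threshold is honored.
        return hugeVectorThreshold > -1 && capacity > hugeVectorThreshold;
      }

      public static void main(String[] args) {
        System.out.println(shouldShrink(0, 16));   // true now; false with the old "> 0" guard
        System.out.println(shouldShrink(-1, 16));  // false: -1 still disables shrinking
      }
    }
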
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index 53640f513fc81..b356751083fc1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -21,6 +21,7 @@ import java.{lang => jl}
import org.apache.spark.annotation.Stable
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.classic.ClassicConversions._
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.ExpressionUtils.column
@@ -33,7 +34,7 @@ import org.apache.spark.sql.types._
*/
@Stable
final class DataFrameNaFunctions private[sql](df: DataFrame)
- extends api.DataFrameNaFunctions[Dataset] {
+ extends api.DataFrameNaFunctions {
import df.sparkSession.RichColumn
protected def drop(minNonNulls: Option[Int]): Dataset[Row] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 9d7a765a24c92..78cc65bb7a298 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -24,14 +24,14 @@ import scala.jdk.CollectionConverters._
import org.apache.spark.Partition
import org.apache.spark.annotation.Stable
import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser}
import org.apache.spark.sql.catalyst.expressions.ExprUtils
import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
-import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils, FailureSafeParser}
+import org.apache.spark.sql.catalyst.util.FailureSafeParser
import org.apache.spark.sql.catalyst.xml.{StaxXmlParser, XmlOptions}
+import org.apache.spark.sql.classic.ClassicConversions._
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.DataSource
@@ -54,136 +54,44 @@ import org.apache.spark.unsafe.types.UTF8String
* @since 1.4.0
*/
@Stable
-class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
+class DataFrameReader private[sql](sparkSession: SparkSession)
+ extends api.DataFrameReader {
+ override type DS[U] = Dataset[U]
- /**
- * Specifies the input data source format.
- *
- * @since 1.4.0
- */
- def format(source: String): DataFrameReader = {
- this.source = source
- this
- }
+ format(sparkSession.sessionState.conf.defaultDataSourceName)
- /**
- * Specifies the input schema. Some data sources (e.g. JSON) can infer the input schema
- * automatically from data. By specifying the schema here, the underlying data source can
- * skip the schema inference step, and thus speed up data loading.
- *
- * @since 1.4.0
- */
- def schema(schema: StructType): DataFrameReader = {
- if (schema != null) {
- val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType]
- this.userSpecifiedSchema = Option(replaced)
- validateSingleVariantColumn()
- }
- this
- }
+ /** @inheritdoc */
+ override def format(source: String): this.type = super.format(source)
- /**
- * Specifies the schema by using the input DDL-formatted string. Some data sources (e.g. JSON) can
- * infer the input schema automatically from data. By specifying the schema here, the underlying
- * data source can skip the schema inference step, and thus speed up data loading.
- *
- * {{{
- * spark.read.schema("a INT, b STRING, c DOUBLE").csv("test.csv")
- * }}}
- *
- * @since 2.3.0
- */
- def schema(schemaString: String): DataFrameReader = {
- schema(StructType.fromDDL(schemaString))
- }
+ /** @inheritdoc */
+ override def schema(schema: StructType): this.type = super.schema(schema)
- /**
- * Adds an input option for the underlying data source.
- *
- * All options are maintained in a case-insensitive way in terms of key names.
- * If a new option has the same key case-insensitively, it will override the existing option.
- *
- * @since 1.4.0
- */
- def option(key: String, value: String): DataFrameReader = {
- this.extraOptions = this.extraOptions + (key -> value)
- validateSingleVariantColumn()
- this
- }
+ /** @inheritdoc */
+ override def schema(schemaString: String): this.type = super.schema(schemaString)
- /**
- * Adds an input option for the underlying data source.
- *
- * All options are maintained in a case-insensitive way in terms of key names.
- * If a new option has the same key case-insensitively, it will override the existing option.
- *
- * @since 2.0.0
- */
- def option(key: String, value: Boolean): DataFrameReader = option(key, value.toString)
+ /** @inheritdoc */
+ override def option(key: String, value: String): this.type = super.option(key, value)
- /**
- * Adds an input option for the underlying data source.
- *
- * All options are maintained in a case-insensitive way in terms of key names.
- * If a new option has the same key case-insensitively, it will override the existing option.
- *
- * @since 2.0.0
- */
- def option(key: String, value: Long): DataFrameReader = option(key, value.toString)
+ /** @inheritdoc */
+ override def option(key: String, value: Boolean): this.type = super.option(key, value)
- /**
- * Adds an input option for the underlying data source.
- *
- * All options are maintained in a case-insensitive way in terms of key names.
- * If a new option has the same key case-insensitively, it will override the existing option.
- *
- * @since 2.0.0
- */
- def option(key: String, value: Double): DataFrameReader = option(key, value.toString)
+ /** @inheritdoc */
+ override def option(key: String, value: Long): this.type = super.option(key, value)
- /**
- * (Scala-specific) Adds input options for the underlying data source.
- *
- * All options are maintained in a case-insensitive way in terms of key names.
- * If a new option has the same key case-insensitively, it will override the existing option.
- *
- * @since 1.4.0
- */
- def options(options: scala.collection.Map[String, String]): DataFrameReader = {
- this.extraOptions ++= options
- validateSingleVariantColumn()
- this
- }
+ /** @inheritdoc */
+ override def option(key: String, value: Double): this.type = super.option(key, value)
- /**
- * Adds input options for the underlying data source.
- *
- * All options are maintained in a case-insensitive way in terms of key names.
- * If a new option has the same key case-insensitively, it will override the existing option.
- *
- * @since 1.4.0
- */
- def options(options: java.util.Map[String, String]): DataFrameReader = {
- this.options(options.asScala)
- this
- }
+ /** @inheritdoc */
+ override def options(options: scala.collection.Map[String, String]): this.type =
+ super.options(options)
- /**
- * Loads input in as a `DataFrame`, for data sources that don't require a path (e.g. external
- * key-value stores).
- *
- * @since 1.4.0
- */
- def load(): DataFrame = {
- load(Seq.empty: _*) // force invocation of `load(...varargs...)`
- }
+ /** @inheritdoc */
+ override def options(options: java.util.Map[String, String]): this.type = super.options(options)
- /**
- * Loads input in as a `DataFrame`, for data sources that require a path (e.g. data backed by
- * a local or distributed file system).
- *
- * @since 1.4.0
- */
+ /** @inheritdoc */
+ override def load(): DataFrame = load(Nil: _*)
+
+ /** @inheritdoc */
def load(path: String): DataFrame = {
// force invocation of `load(...varargs...)`
if (sparkSession.sessionState.conf.legacyPathOptionBehavior) {
@@ -193,12 +101,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
}
}
- /**
- * Loads input in as a `DataFrame`, for data sources that support multiple paths.
- * Only works if the source is a HadoopFsRelationProvider.
- *
- * @since 1.6.0
- */
+ /** @inheritdoc */
@scala.annotation.varargs
def load(paths: String*): DataFrame = {
if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) {
@@ -235,90 +138,22 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
options = finalOptions.originalMap).resolveRelation())
}
- /**
- * Construct a `DataFrame` representing the database table accessible via JDBC URL
- * url named table and connection properties.
- *
- * You can find the JDBC-specific option and parameter documentation for reading tables
- * via JDBC in
- * Data Source Option in the version you use.
- *
- * @since 1.4.0
- */
- def jdbc(url: String, table: String, properties: Properties): DataFrame = {
- assertNoSpecifiedSchema("jdbc")
- // properties should override settings in extraOptions.
- this.extraOptions ++= properties.asScala
- // explicit url and dbtable should override all
- this.extraOptions ++= Seq(JDBCOptions.JDBC_URL -> url, JDBCOptions.JDBC_TABLE_NAME -> table)
- format("jdbc").load()
- }
+ /** @inheritdoc */
+ override def jdbc(url: String, table: String, properties: Properties): DataFrame =
+ super.jdbc(url, table, properties)
- // scalastyle:off line.size.limit
- /**
- * Construct a `DataFrame` representing the database table accessible via JDBC URL
- * url named table. Partitions of the table will be retrieved in parallel based on the parameters
- * passed to this function.
- *
- * Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash
- * your external database systems.
- *
- * You can find the JDBC-specific option and parameter documentation for reading tables via JDBC in
- * Data Source Option in the version you use.
- *
- * @param table Name of the table in the external database.
- * @param columnName Alias of `partitionColumn` option. Refer to `partitionColumn` in
- * Data Source Option in the version you use.
- * @param connectionProperties JDBC database connection arguments, a list of arbitrary string
- * tag/value. Normally at least a "user" and "password" property
- * should be included. "fetchsize" can be used to control the
- * number of rows per fetch and "queryTimeout" can be used to wait
- * for a Statement object to execute to the given number of seconds.
- * @since 1.4.0
- */
- // scalastyle:on line.size.limit
- def jdbc(
+ /** @inheritdoc */
+ override def jdbc(
url: String,
table: String,
columnName: String,
lowerBound: Long,
upperBound: Long,
numPartitions: Int,
- connectionProperties: Properties): DataFrame = {
- // columnName, lowerBound, upperBound and numPartitions override settings in extraOptions.
- this.extraOptions ++= Map(
- JDBCOptions.JDBC_PARTITION_COLUMN -> columnName,
- JDBCOptions.JDBC_LOWER_BOUND -> lowerBound.toString,
- JDBCOptions.JDBC_UPPER_BOUND -> upperBound.toString,
- JDBCOptions.JDBC_NUM_PARTITIONS -> numPartitions.toString)
- jdbc(url, table, connectionProperties)
- }
+ connectionProperties: Properties): DataFrame =
+ super.jdbc(url, table, columnName, lowerBound, upperBound, numPartitions, connectionProperties)
- /**
- * Construct a `DataFrame` representing the database table accessible via JDBC URL
- * url named table using connection properties. The `predicates` parameter gives a list
- * expressions suitable for inclusion in WHERE clauses; each one defines one partition
- * of the `DataFrame`.
- *
- * Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash
- * your external database systems.
- *
- * You can find the JDBC-specific option and parameter documentation for reading tables
- * via JDBC in
- * Data Source Option in the version you use.
- *
- * @param table Name of the table in the external database.
- * @param predicates Condition in the where clause for each partition.
- * @param connectionProperties JDBC database connection arguments, a list of arbitrary string
- * tag/value. Normally at least a "user" and "password" property
- * should be included. "fetchsize" can be used to control the
- * number of rows per fetch.
- * @since 1.4.0
- */
+ /** @inheritdoc */
def jdbc(
url: String,
table: String,
@@ -335,38 +170,12 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
sparkSession.baseRelationToDataFrame(relation)
}
- /**
- * Loads a JSON file and returns the results as a `DataFrame`.
- *
- * See the documentation on the overloaded `json()` method with varargs for more details.
- *
- * @since 1.4.0
- */
- def json(path: String): DataFrame = {
- // This method ensures that calls that explicit need single argument works, see SPARK-16009
- json(Seq(path): _*)
- }
+ /** @inheritdoc */
+ override def json(path: String): DataFrame = super.json(path)
- /**
- * Loads JSON files and returns the results as a `DataFrame`.
- *
- * JSON Lines (newline-delimited JSON) is supported by
- * default. For JSON (one record per file), set the `multiLine` option to true.
- *
- * This function goes through the input once to determine the input schema. If you know the
- * schema in advance, use the version that specifies the schema to avoid the extra scan.
- *
- * You can find the JSON-specific options for reading JSON files in
- * Data Source Option in the version you use.
- *
- * @since 2.0.0
- */
+ /** @inheritdoc */
@scala.annotation.varargs
- def json(paths: String*): DataFrame = {
- userSpecifiedSchema.foreach(checkJsonSchema)
- format("json").load(paths : _*)
- }
+ override def json(paths: String*): DataFrame = super.json(paths: _*)
/**
* Loads a `JavaRDD[String]` storing JSON objects (JSON
@@ -397,16 +206,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
json(sparkSession.createDataset(jsonRDD)(Encoders.STRING))
}
- /**
- * Loads a `Dataset[String]` storing JSON objects (JSON Lines
- * text format or newline-delimited JSON) and returns the result as a `DataFrame`.
- *
- * Unless the schema is specified using `schema` function, this function goes through the
- * input once to determine the input schema.
- *
- * @param jsonDataset input Dataset with one JSON object per record
- * @since 2.2.0
- */
+ /** @inheritdoc */
def json(jsonDataset: Dataset[String]): DataFrame = {
val parsedOptions = new JSONOptions(
extraOptions.toMap,
@@ -439,36 +239,10 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
sparkSession.internalCreateDataFrame(parsed, schema, isStreaming = jsonDataset.isStreaming)
}
- /**
- * Loads a CSV file and returns the result as a `DataFrame`. See the documentation on the
- * other overloaded `csv()` method for more details.
- *
- * @since 2.0.0
- */
- def csv(path: String): DataFrame = {
- // This method ensures that calls that explicit need single argument works, see SPARK-16009
- csv(Seq(path): _*)
- }
+ /** @inheritdoc */
+ override def csv(path: String): DataFrame = super.csv(path)
- /**
- * Loads an `Dataset[String]` storing CSV rows and returns the result as a `DataFrame`.
- *
- * If the schema is not specified using `schema` function and `inferSchema` option is enabled,
- * this function goes through the input once to determine the input schema.
- *
- * If the schema is not specified using `schema` function and `inferSchema` option is disabled,
- * it determines the columns as string types and it reads only the first line to determine the
- * names and the number of fields.
- *
- * If the enforceSchema is set to `false`, only the CSV header in the first line is checked
- * to conform specified or inferred schema.
- *
- * @note if `header` option is set to `true` when calling this API, all lines same with
- * the header will be removed if exists.
- *
- * @param csvDataset input Dataset with one CSV row per record
- * @since 2.2.0
- */
+ /** @inheritdoc */
def csv(csvDataset: Dataset[String]): DataFrame = {
val parsedOptions: CSVOptions = new CSVOptions(
extraOptions.toMap,
@@ -527,61 +301,18 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
sparkSession.internalCreateDataFrame(parsed, schema, isStreaming = csvDataset.isStreaming)
}
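For context on the hunk above: the `csv(csvDataset: Dataset[String])` variant stays concrete because it parses an in-memory Dataset instead of delegating to the parent reader. A minimal usage sketch, assuming an active SparkSession bound to a stable identifier `spark` (all values below are illustrative, not from this patch):

    // Sketch only: assumes `val spark: SparkSession` is in scope.
    import spark.implicits._

    val csvLines = Seq("name,age", "alice,29", "bob,31").toDS()
    val people = spark.read
      .option("header", true)        // first line carries column names
      .option("inferSchema", true)   // infer Int for the age column
      .csv(csvLines)                 // parse the in-memory Dataset[String]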
- /**
- * Loads CSV files and returns the result as a `DataFrame`.
- *
- * This function will go through the input once to determine the input schema if `inferSchema`
- * is enabled. To avoid going through the entire data once, disable `inferSchema` option or
- * specify the schema explicitly using `schema`.
- *
- * You can find the CSV-specific options for reading CSV files in
- * Data Source Option in the version you use.
- *
- * @since 2.0.0
- */
+ /** @inheritdoc */
@scala.annotation.varargs
- def csv(paths: String*): DataFrame = format("csv").load(paths : _*)
+ override def csv(paths: String*): DataFrame = super.csv(paths: _*)
- /**
- * Loads a XML file and returns the result as a `DataFrame`. See the documentation on the
- * other overloaded `xml()` method for more details.
- *
- * @since 4.0.0
- */
- def xml(path: String): DataFrame = {
- // This method ensures that calls that explicit need single argument works, see SPARK-16009
- xml(Seq(path): _*)
- }
+ /** @inheritdoc */
+ override def xml(path: String): DataFrame = super.xml(path)
- /**
- * Loads XML files and returns the result as a `DataFrame`.
- *
- * This function will go through the input once to determine the input schema if `inferSchema`
- * is enabled. To avoid going through the entire data once, disable `inferSchema` option or
- * specify the schema explicitly using `schema`.
- *
- * You can find the XML-specific options for reading XML files in
- * Data Source Option in the version you use.
- *
- * @since 4.0.0
- */
+ /** @inheritdoc */
@scala.annotation.varargs
- def xml(paths: String*): DataFrame = {
- userSpecifiedSchema.foreach(checkXmlSchema)
- format("xml").load(paths: _*)
- }
+ override def xml(paths: String*): DataFrame = super.xml(paths: _*)
- /**
- * Loads an `Dataset[String]` storing XML object and returns the result as a `DataFrame`.
- *
- * If the schema is not specified using `schema` function and `inferSchema` option is enabled,
- * this function goes through the input once to determine the input schema.
- *
- * @param xmlDataset input Dataset with one XML object per record
- * @since 4.0.0
- */
+ /** @inheritdoc */
def xml(xmlDataset: Dataset[String]): DataFrame = {
val parsedOptions: XmlOptions = new XmlOptions(
extraOptions.toMap,
@@ -614,70 +345,21 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
sparkSession.internalCreateDataFrame(parsed, schema, isStreaming = xmlDataset.isStreaming)
}
- /**
- * Loads a Parquet file, returning the result as a `DataFrame`. See the documentation
- * on the other overloaded `parquet()` method for more details.
- *
- * @since 2.0.0
- */
- def parquet(path: String): DataFrame = {
- // This method ensures that calls that explicit need single argument works, see SPARK-16009
- parquet(Seq(path): _*)
- }
+ /** @inheritdoc */
+ override def parquet(path: String): DataFrame = super.parquet(path)
- /**
- * Loads a Parquet file, returning the result as a `DataFrame`.
- *
- * Parquet-specific option(s) for reading Parquet files can be found in
- * Data Source Option in the version you use.
- *
- * @since 1.4.0
- */
+ /** @inheritdoc */
@scala.annotation.varargs
- def parquet(paths: String*): DataFrame = {
- format("parquet").load(paths: _*)
- }
+ override def parquet(paths: String*): DataFrame = super.parquet(paths: _*)
- /**
- * Loads an ORC file and returns the result as a `DataFrame`.
- *
- * @param path input path
- * @since 1.5.0
- */
- def orc(path: String): DataFrame = {
- // This method ensures that calls that explicit need single argument works, see SPARK-16009
- orc(Seq(path): _*)
- }
+ /** @inheritdoc */
+ override def orc(path: String): DataFrame = super.orc(path)
- /**
- * Loads ORC files and returns the result as a `DataFrame`.
- *
- * ORC-specific option(s) for reading ORC files can be found in
- * Data Source Option in the version you use.
- *
- * @param paths input paths
- * @since 2.0.0
- */
+ /** @inheritdoc */
@scala.annotation.varargs
- def orc(paths: String*): DataFrame = format("orc").load(paths: _*)
+ override def orc(paths: String*): DataFrame = super.orc(paths: _*)
- /**
- * Returns the specified table/view as a `DataFrame`. If it's a table, it must support batch
- * reading and the returned DataFrame is the batch scan query plan of this table. If it's a view,
- * the returned DataFrame is simply the query plan of the view, which can either be a batch or
- * streaming query plan.
- *
- * @param tableName is either a qualified or unqualified name that designates a table or view.
- * If a database is specified, it identifies the table/view from the database.
- * Otherwise, it first attempts to find a temporary view with the given name
- * and then match the table/view from the current database.
- * Note that, the global temporary view database is also valid here.
- * @since 1.4.0
- */
+ /** @inheritdoc */
def table(tableName: String): DataFrame = {
assertNoSpecifiedSchema("table")
val multipartIdentifier =
@@ -686,108 +368,31 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
new CaseInsensitiveStringMap(extraOptions.toMap.asJava)))
}
- /**
- * Loads text files and returns a `DataFrame` whose schema starts with a string column named
- * "value", and followed by partitioned columns if there are any. See the documentation on
- * the other overloaded `text()` method for more details.
- *
- * @since 2.0.0
- */
- def text(path: String): DataFrame = {
- // This method ensures that calls that explicit need single argument works, see SPARK-16009
- text(Seq(path): _*)
- }
+ /** @inheritdoc */
+ override def text(path: String): DataFrame = super.text(path)
- /**
- * Loads text files and returns a `DataFrame` whose schema starts with a string column named
- * "value", and followed by partitioned columns if there are any.
- * The text files must be encoded as UTF-8.
- *
- * By default, each line in the text files is a new row in the resulting DataFrame. For example:
- * {{{
- * // Scala:
- * spark.read.text("/path/to/spark/README.md")
- *
- * // Java:
- * spark.read().text("/path/to/spark/README.md")
- * }}}
- *
- * You can find the text-specific options for reading text files in
- * Data Source Option in the version you use.
- *
- * @param paths input paths
- * @since 1.6.0
- */
+ /** @inheritdoc */
@scala.annotation.varargs
- def text(paths: String*): DataFrame = format("text").load(paths : _*)
+ override def text(paths: String*): DataFrame = super.text(paths: _*)
- /**
- * Loads text files and returns a [[Dataset]] of String. See the documentation on the
- * other overloaded `textFile()` method for more details.
- * @since 2.0.0
- */
- def textFile(path: String): Dataset[String] = {
- // This method ensures that calls that explicit need single argument works, see SPARK-16009
- textFile(Seq(path): _*)
- }
+ /** @inheritdoc */
+ override def textFile(path: String): Dataset[String] = super.textFile(path)
- /**
- * Loads text files and returns a [[Dataset]] of String. The underlying schema of the Dataset
- * contains a single string column named "value".
- * The text files must be encoded as UTF-8.
- *
- * If the directory structure of the text files contains partitioning information, those are
- * ignored in the resulting Dataset. To include partitioning information as columns, use `text`.
- *
- * By default, each line in the text files is a new row in the resulting DataFrame. For example:
- * {{{
- * // Scala:
- * spark.read.textFile("/path/to/spark/README.md")
- *
- * // Java:
- * spark.read().textFile("/path/to/spark/README.md")
- * }}}
- *
- * You can set the text-specific options as specified in `DataFrameReader.text`.
- *
- * @param paths input path
- * @since 2.0.0
- */
+ /** @inheritdoc */
@scala.annotation.varargs
- def textFile(paths: String*): Dataset[String] = {
- assertNoSpecifiedSchema("textFile")
- text(paths : _*).select("value").as[String](sparkSession.implicits.newStringEncoder)
- }
+ override def textFile(paths: String*): Dataset[String] = super.textFile(paths: _*)
- /**
- * A convenient function for schema validation in APIs.
- */
- private def assertNoSpecifiedSchema(operation: String): Unit = {
- if (userSpecifiedSchema.nonEmpty) {
- throw QueryCompilationErrors.userSpecifiedSchemaUnsupportedError(operation)
- }
- }
-
- /**
- * Ensure that the `singleVariantColumn` option cannot be used if there is also a user specified
- * schema.
- */
- private def validateSingleVariantColumn(): Unit = {
+ /** @inheritdoc */
+ override protected def validateSingleVariantColumn(): Unit = {
if (extraOptions.get(JSONOptions.SINGLE_VARIANT_COLUMN).isDefined &&
userSpecifiedSchema.isDefined) {
throw QueryCompilationErrors.invalidSingleVariantColumn()
}
}
- ///////////////////////////////////////////////////////////////////////////////////////
- // Builder pattern config options
- ///////////////////////////////////////////////////////////////////////////////////////
-
- private var source: String = sparkSession.sessionState.conf.defaultDataSourceName
-
- private var userSpecifiedSchema: Option[StructType] = None
-
- private var extraOptions = CaseInsensitiveMap[String](Map.empty)
+ override protected def validateJsonSchema(): Unit =
+ userSpecifiedSchema.foreach(checkJsonSchema)
+ override protected def validateXmlSchema(): Unit =
+ userSpecifiedSchema.foreach(checkXmlSchema)
}
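The DataFrameReader changes above move the shared builder methods (the option/options overloads and the no-path load) up to the parent api.DataFrameReader, with the classic reader only overriding them to narrow the return type; the fluent call pattern users write is unchanged. A minimal sketch under that assumption (app name and input path are placeholders):

    // Sketch only: placeholder app name and path.
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().appName("reader-sketch").master("local[*]").getOrCreate()
    val df = spark.read
      .format("json")
      .option("multiLine", true)      // Boolean overload, now returning this.type
      .option("samplingRatio", 0.5)   // Double overload
      .load("/tmp/people.json")
    df.printSchema()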
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index a5ab237bb7041..9f7180d8dfd6a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -23,6 +23,7 @@ import scala.jdk.CollectionConverters._
import org.apache.spark.annotation.Stable
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
+import org.apache.spark.sql.classic.ClassicConversions._
import org.apache.spark.sql.execution.stat._
import org.apache.spark.sql.functions.col
import org.apache.spark.util.ArrayImplicits._
@@ -34,7 +35,7 @@ import org.apache.spark.util.ArrayImplicits._
*/
@Stable
final class DataFrameStatFunctions private[sql](protected val df: DataFrame)
- extends api.DataFrameStatFunctions[Dataset] {
+ extends api.DataFrameStatFunctions {
/** @inheritdoc */
def approxQuantile(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala
deleted file mode 100644
index 576d8276b56ef..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala
+++ /dev/null
@@ -1,404 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql
-
-import scala.collection.mutable
-import scala.jdk.CollectionConverters._
-
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException, UnresolvedFunction, UnresolvedIdentifier, UnresolvedRelation}
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Literal}
-import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, LogicalPlan, OptionList, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, UnresolvedTableSpec}
-import org.apache.spark.sql.connector.catalog.TableWritePrivilege._
-import org.apache.spark.sql.connector.expressions.{ClusterByTransform, FieldReference, LogicalExpressions, NamedReference, Transform}
-import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.QueryExecution
-import org.apache.spark.sql.types.IntegerType
-
-/**
- * Interface used to write a [[org.apache.spark.sql.Dataset]] to external storage using the v2 API.
- *
- * @since 3.0.0
- */
-@Experimental
-final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
- extends CreateTableWriter[T] {
-
- private val df: DataFrame = ds.toDF()
-
- private val sparkSession = ds.sparkSession
- import sparkSession.expression
-
- private val tableName = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(table)
-
- private val logicalPlan = df.queryExecution.logical
-
- private var provider: Option[String] = None
-
- private val options = new mutable.HashMap[String, String]()
-
- private val properties = new mutable.HashMap[String, String]()
-
- private var partitioning: Option[Seq[Transform]] = None
-
- private var clustering: Option[ClusterByTransform] = None
-
- override def using(provider: String): CreateTableWriter[T] = {
- this.provider = Some(provider)
- this
- }
-
- override def option(key: String, value: String): DataFrameWriterV2[T] = {
- this.options.put(key, value)
- this
- }
-
- override def options(options: scala.collection.Map[String, String]): DataFrameWriterV2[T] = {
- options.foreach {
- case (key, value) =>
- this.options.put(key, value)
- }
- this
- }
-
- override def options(options: java.util.Map[String, String]): DataFrameWriterV2[T] = {
- this.options(options.asScala)
- this
- }
-
- override def tableProperty(property: String, value: String): CreateTableWriter[T] = {
- this.properties.put(property, value)
- this
- }
-
- @scala.annotation.varargs
- override def partitionedBy(column: Column, columns: Column*): CreateTableWriter[T] = {
- def ref(name: String): NamedReference = LogicalExpressions.parseReference(name)
-
- val asTransforms = (column +: columns).map(expression).map {
- case PartitionTransform.YEARS(Seq(attr: Attribute)) =>
- LogicalExpressions.years(ref(attr.name))
- case PartitionTransform.MONTHS(Seq(attr: Attribute)) =>
- LogicalExpressions.months(ref(attr.name))
- case PartitionTransform.DAYS(Seq(attr: Attribute)) =>
- LogicalExpressions.days(ref(attr.name))
- case PartitionTransform.HOURS(Seq(attr: Attribute)) =>
- LogicalExpressions.hours(ref(attr.name))
- case PartitionTransform.BUCKET(Seq(Literal(numBuckets: Int, IntegerType), attr: Attribute)) =>
- LogicalExpressions.bucket(numBuckets, Array(ref(attr.name)))
- case PartitionTransform.BUCKET(Seq(numBuckets, e)) =>
- throw QueryCompilationErrors.invalidBucketsNumberError(numBuckets.toString, e.toString)
- case attr: Attribute =>
- LogicalExpressions.identity(ref(attr.name))
- case expr =>
- throw QueryCompilationErrors.invalidPartitionTransformationError(expr)
- }
-
- this.partitioning = Some(asTransforms)
- validatePartitioning()
- this
- }
-
- @scala.annotation.varargs
- override def clusterBy(colName: String, colNames: String*): CreateTableWriter[T] = {
- this.clustering =
- Some(ClusterByTransform((colName +: colNames).map(col => FieldReference(col))))
- validatePartitioning()
- this
- }
-
- /**
- * Validate that clusterBy is not used with partitionBy.
- */
- private def validatePartitioning(): Unit = {
- if (partitioning.nonEmpty && clustering.nonEmpty) {
- throw QueryCompilationErrors.clusterByWithPartitionedBy()
- }
- }
-
- override def create(): Unit = {
- val tableSpec = UnresolvedTableSpec(
- properties = properties.toMap,
- provider = provider,
- optionExpression = OptionList(Seq.empty),
- location = None,
- comment = None,
- serde = None,
- external = false)
- runCommand(
- CreateTableAsSelect(
- UnresolvedIdentifier(tableName),
- partitioning.getOrElse(Seq.empty) ++ clustering,
- logicalPlan,
- tableSpec,
- options.toMap,
- false))
- }
-
- override def replace(): Unit = {
- internalReplace(orCreate = false)
- }
-
- override def createOrReplace(): Unit = {
- internalReplace(orCreate = true)
- }
-
-
- /**
- * Append the contents of the data frame to the output table.
- *
- * If the output table does not exist, this operation will fail with
- * [[org.apache.spark.sql.catalyst.analysis.NoSuchTableException]]. The data frame will be
- * validated to ensure it is compatible with the existing table.
- *
- * @throws org.apache.spark.sql.catalyst.analysis.NoSuchTableException If the table does not exist
- */
- @throws(classOf[NoSuchTableException])
- def append(): Unit = {
- val append = AppendData.byName(
- UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT)),
- logicalPlan, options.toMap)
- runCommand(append)
- }
-
- /**
- * Overwrite rows matching the given filter condition with the contents of the data frame in
- * the output table.
- *
- * If the output table does not exist, this operation will fail with
- * [[org.apache.spark.sql.catalyst.analysis.NoSuchTableException]].
- * The data frame will be validated to ensure it is compatible with the existing table.
- *
- * @throws org.apache.spark.sql.catalyst.analysis.NoSuchTableException If the table does not exist
- */
- @throws(classOf[NoSuchTableException])
- def overwrite(condition: Column): Unit = {
- val overwrite = OverwriteByExpression.byName(
- UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT, DELETE)),
- logicalPlan, expression(condition), options.toMap)
- runCommand(overwrite)
- }
-
- /**
- * Overwrite all partition for which the data frame contains at least one row with the contents
- * of the data frame in the output table.
- *
- * This operation is equivalent to Hive's `INSERT OVERWRITE ... PARTITION`, which replaces
- * partitions dynamically depending on the contents of the data frame.
- *
- * If the output table does not exist, this operation will fail with
- * [[org.apache.spark.sql.catalyst.analysis.NoSuchTableException]]. The data frame will be
- * validated to ensure it is compatible with the existing table.
- *
- * @throws org.apache.spark.sql.catalyst.analysis.NoSuchTableException If the table does not exist
- */
- @throws(classOf[NoSuchTableException])
- def overwritePartitions(): Unit = {
- val dynamicOverwrite = OverwritePartitionsDynamic.byName(
- UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT, DELETE)),
- logicalPlan, options.toMap)
- runCommand(dynamicOverwrite)
- }
-
- /**
- * Wrap an action to track the QueryExecution and time cost, then report to the user-registered
- * callback functions.
- */
- private def runCommand(command: LogicalPlan): Unit = {
- val qe = new QueryExecution(sparkSession, command, df.queryExecution.tracker)
- qe.assertCommandExecuted()
- }
-
- private def internalReplace(orCreate: Boolean): Unit = {
- val tableSpec = UnresolvedTableSpec(
- properties = properties.toMap,
- provider = provider,
- optionExpression = OptionList(Seq.empty),
- location = None,
- comment = None,
- serde = None,
- external = false)
- runCommand(ReplaceTableAsSelect(
- UnresolvedIdentifier(tableName),
- partitioning.getOrElse(Seq.empty) ++ clustering,
- logicalPlan,
- tableSpec,
- writeOptions = options.toMap,
- orCreate = orCreate))
- }
-}
-
-private object PartitionTransform {
- class ExtractTransform(name: String) {
- private val NAMES = Seq(name)
-
- def unapply(e: Expression): Option[Seq[Expression]] = e match {
- case UnresolvedFunction(NAMES, children, false, None, false, Nil, true) => Option(children)
- case _ => None
- }
- }
-
- val HOURS = new ExtractTransform("hours")
- val DAYS = new ExtractTransform("days")
- val MONTHS = new ExtractTransform("months")
- val YEARS = new ExtractTransform("years")
- val BUCKET = new ExtractTransform("bucket")
-}
-
-/**
- * Configuration methods common to create/replace operations and insert/overwrite operations.
- * @tparam R builder type to return
- * @since 3.0.0
- */
-trait WriteConfigMethods[R] {
- /**
- * Add a write option.
- *
- * @since 3.0.0
- */
- def option(key: String, value: String): R
-
- /**
- * Add a boolean output option.
- *
- * @since 3.0.0
- */
- def option(key: String, value: Boolean): R = option(key, value.toString)
-
- /**
- * Add a long output option.
- *
- * @since 3.0.0
- */
- def option(key: String, value: Long): R = option(key, value.toString)
-
- /**
- * Add a double output option.
- *
- * @since 3.0.0
- */
- def option(key: String, value: Double): R = option(key, value.toString)
-
- /**
- * Add write options from a Scala Map.
- *
- * @since 3.0.0
- */
- def options(options: scala.collection.Map[String, String]): R
-
- /**
- * Add write options from a Java Map.
- *
- * @since 3.0.0
- */
- def options(options: java.util.Map[String, String]): R
-}
-
-/**
- * Trait to restrict calls to create and replace operations.
- *
- * @since 3.0.0
- */
-trait CreateTableWriter[T] extends WriteConfigMethods[CreateTableWriter[T]] {
- /**
- * Create a new table from the contents of the data frame.
- *
- * The new table's schema, partition layout, properties, and other configuration will be
- * based on the configuration set on this writer.
- *
- * If the output table exists, this operation will fail with
- * [[org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException]].
- *
- * @throws org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
- * If the table already exists
- */
- @throws(classOf[TableAlreadyExistsException])
- def create(): Unit
-
- /**
- * Replace an existing table with the contents of the data frame.
- *
- * The existing table's schema, partition layout, properties, and other configuration will be
- * replaced with the contents of the data frame and the configuration set on this writer.
- *
- * If the output table does not exist, this operation will fail with
- * [[org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException]].
- *
- * @throws org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException
- * If the table does not exist
- */
- @throws(classOf[CannotReplaceMissingTableException])
- def replace(): Unit
-
- /**
- * Create a new table or replace an existing table with the contents of the data frame.
- *
- * The output table's schema, partition layout, properties, and other configuration will be based
- * on the contents of the data frame and the configuration set on this writer. If the table
- * exists, its configuration and data will be replaced.
- */
- def createOrReplace(): Unit
-
- /**
- * Partition the output table created by `create`, `createOrReplace`, or `replace` using
- * the given columns or transforms.
- *
- * When specified, the table data will be stored by these values for efficient reads.
- *
- * For example, when a table is partitioned by day, it may be stored in a directory layout like:
- *
- *   - `table/day=2019-06-01/`
- *   - `table/day=2019-06-02/`
- *
- * Partitioning is one of the most widely used techniques to optimize physical data layout.
- * It provides a coarse-grained index for skipping unnecessary data reads when queries have
- * predicates on the partitioned columns. In order for partitioning to work well, the number
- * of distinct values in each column should typically be less than tens of thousands.
- *
- * @since 3.0.0
- */
- @scala.annotation.varargs
- def partitionedBy(column: Column, columns: Column*): CreateTableWriter[T]
-
- /**
- * Clusters the output by the given columns on the storage. The rows with matching values in
- * the specified clustering columns will be consolidated within the same group.
- *
- * For instance, if you cluster a dataset by date, the data sharing the same date will be stored
- * together in a file. This arrangement improves query efficiency when you apply selective
- * filters to these clustering columns, thanks to data skipping.
- *
- * @since 4.0.0
- */
- @scala.annotation.varargs
- def clusterBy(colName: String, colNames: String*): CreateTableWriter[T]
-
- /**
- * Specifies a provider for the underlying output data source. Spark's default catalog supports
- * "parquet", "json", etc.
- *
- * @since 3.0.0
- */
- def using(provider: String): CreateTableWriter[T]
-
- /**
- * Add a table property.
- */
- def tableProperty(property: String, value: String): CreateTableWriter[T]
-}
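The file deleted above only relocates the v2 writer: as the Dataset.scala hunks below show, `Dataset.writeTo` now hands back the internal `DataFrameWriterV2Impl`, while the public `CreateTableWriter`/`WriteConfigMethods` surface stays the same. A usage sketch, assuming a DataFrame `df`, `spark.implicits._` in scope, and a placeholder catalog table name:

    // Sketch only: "catalog.db.events" and the `day` column are placeholders.
    df.writeTo("catalog.db.events")
      .using("parquet")
      .partitionedBy($"day")
      .createOrReplace()

    // Appending to an existing table keeps the same builder shape:
    df.writeTo("catalog.db.events").append()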
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 38521e8e16f91..ef628ca612b49 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -42,6 +42,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, Query
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
import org.apache.spark.sql.catalyst.encoders._
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, ProductEncoder, StructEncoder}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions}
import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
@@ -51,6 +52,7 @@ import org.apache.spark.sql.catalyst.trees.{TreeNodeTag, TreePattern}
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, IntervalUtils}
import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId
+import org.apache.spark.sql.classic.ClassicConversions._
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
@@ -60,7 +62,7 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation, FileTable}
import org.apache.spark.sql.execution.python.EvaluatePython
import org.apache.spark.sql.execution.stat.StatFunctions
-import org.apache.spark.sql.internal.{DataFrameWriterImpl, SQLConf}
+import org.apache.spark.sql.internal.{DataFrameWriterImpl, DataFrameWriterV2Impl, MergeIntoWriterImpl, SQLConf}
import org.apache.spark.sql.internal.ExpressionUtils.column
import org.apache.spark.sql.internal.TypedAggUtils.withInputType
import org.apache.spark.sql.streaming.DataStreamWriter
@@ -78,13 +80,14 @@ private[sql] object Dataset {
val DATASET_ID_TAG = TreeNodeTag[HashSet[Long]]("dataset_id")
def apply[T: Encoder](sparkSession: SparkSession, logicalPlan: LogicalPlan): Dataset[T] = {
- val dataset = new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]])
+ val encoder = implicitly[Encoder[T]]
+ val dataset = new Dataset(sparkSession, logicalPlan, encoder)
// Eagerly bind the encoder so we verify that the encoder matches the underlying
// schema. The user will get an error if this is not the case.
// optimization: it is guaranteed that [[InternalRow]] can be converted to [[Row]] so
// do not do this check in that case. this check can be expensive since it requires running
// the whole [[Analyzer]] to resolve the deserializer
- if (dataset.exprEnc.clsTag.runtimeClass != classOf[Row]) {
+ if (dataset.encoder.clsTag.runtimeClass != classOf[Row]) {
dataset.resolvedEnc
}
dataset
@@ -94,7 +97,7 @@ private[sql] object Dataset {
sparkSession.withActive {
val qe = sparkSession.sessionState.executePlan(logicalPlan)
qe.assertAnalyzed()
- new Dataset[Row](qe, ExpressionEncoder(qe.analyzed.schema))
+ new Dataset[Row](qe, RowEncoder.encoderFor(qe.analyzed.schema))
}
def ofRows(
@@ -105,7 +108,7 @@ private[sql] object Dataset {
val qe = new QueryExecution(
sparkSession, logicalPlan, shuffleCleanupMode = shuffleCleanupMode)
qe.assertAnalyzed()
- new Dataset[Row](qe, ExpressionEncoder(qe.analyzed.schema))
+ new Dataset[Row](qe, RowEncoder.encoderFor(qe.analyzed.schema))
}
/** A variant of ofRows that allows passing in a tracker so we can track query parsing time. */
@@ -118,7 +121,7 @@ private[sql] object Dataset {
val qe = new QueryExecution(
sparkSession, logicalPlan, tracker, shuffleCleanupMode = shuffleCleanupMode)
qe.assertAnalyzed()
- new Dataset[Row](qe, ExpressionEncoder(qe.analyzed.schema))
+ new Dataset[Row](qe, RowEncoder.encoderFor(qe.analyzed.schema))
}
}
@@ -213,7 +216,8 @@ private[sql] object Dataset {
class Dataset[T] private[sql](
@DeveloperApi @Unstable @transient val queryExecution: QueryExecution,
@DeveloperApi @Unstable @transient val encoder: Encoder[T])
- extends api.Dataset[T, Dataset] {
+ extends api.Dataset[T] {
+ type DS[U] = Dataset[U]
type RGD = RelationalGroupedDataset
@transient lazy val sparkSession: SparkSession = {
@@ -243,7 +247,7 @@ class Dataset[T] private[sql](
@transient private[sql] val logicalPlan: LogicalPlan = {
val plan = queryExecution.commandExecuted
- if (sparkSession.conf.get(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED)) {
+ if (sparkSession.sessionState.conf.getConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED)) {
val dsIds = plan.getTagValue(Dataset.DATASET_ID_TAG).getOrElse(new HashSet[Long])
dsIds.add(id)
plan.setTagValue(Dataset.DATASET_ID_TAG, dsIds)
@@ -252,12 +256,17 @@ class Dataset[T] private[sql](
}
/**
- * Currently [[ExpressionEncoder]] is the only implementation of [[Encoder]], here we turn the
- * passed in encoder to [[ExpressionEncoder]] explicitly, and mark it implicit so that we can use
- * it when constructing new Dataset objects that have the same object type (that will be
- * possibly resolved to a different schema).
+ * Expose the encoder as implicit so it can be used to construct new Dataset objects that have
+ * the same external type.
*/
- private[sql] implicit val exprEnc: ExpressionEncoder[T] = encoderFor(encoder)
+ private implicit def encoderImpl: Encoder[T] = encoder
+
+ /**
+ * The actual [[ExpressionEncoder]] used by the dataset. This and its resolved counterpart should
+ * only be used for actual (de)serialization, the binding of Aggregator inputs, and in the rare
+ * cases where a plan needs to be constructed with an ExpressionEncoder.
+ */
+ private[sql] lazy val exprEnc: ExpressionEncoder[T] = encoderFor(encoder)
// The resolved `ExpressionEncoder` which can be used to turn rows to objects of type T, after
// collecting rows to the driver side.
@@ -265,7 +274,7 @@ class Dataset[T] private[sql](
exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer)
}
- private implicit def classTag: ClassTag[T] = exprEnc.clsTag
+ private implicit def classTag: ClassTag[T] = encoder.clsTag
// sqlContext must be val because a stable identifier is expected when you import implicits
@transient lazy val sqlContext: SQLContext = sparkSession.sqlContext
@@ -476,7 +485,7 @@ class Dataset[T] private[sql](
/** @inheritdoc */
// This is declared with parentheses to prevent the Scala compiler from treating
// `ds.toDF("1")` as invoking this toDF and then apply on the returned DataFrame.
- def toDF(): DataFrame = new Dataset[Row](queryExecution, ExpressionEncoder(schema))
+ def toDF(): DataFrame = new Dataset[Row](queryExecution, RowEncoder.encoderFor(schema))
/** @inheritdoc */
def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, logicalPlan)
@@ -671,17 +680,17 @@ class Dataset[T] private[sql](
Some(condition.expr),
JoinHint.NONE)).analyzed.asInstanceOf[Join]
- implicit val tuple2Encoder: Encoder[(T, U)] =
- ExpressionEncoder
- .tuple(Seq(this.exprEnc, other.exprEnc), useNullSafeDeserializer = true)
- .asInstanceOf[Encoder[(T, U)]]
-
- withTypedPlan(JoinWith.typedJoinWith(
+ val leftEncoder = agnosticEncoderFor(encoder)
+ val rightEncoder = agnosticEncoderFor(other.encoder)
+ val joinEncoder = ProductEncoder.tuple(Seq(leftEncoder, rightEncoder), elementsCanBeNull = true)
+ .asInstanceOf[Encoder[(T, U)]]
+ val joinWith = JoinWith.typedJoinWith(
joined,
sparkSession.sessionState.conf.dataFrameSelfJoinAutoResolveAmbiguity,
sparkSession.sessionState.analyzer.resolver,
- this.exprEnc.isSerializedAsStructForTopLevel,
- other.exprEnc.isSerializedAsStructForTopLevel))
+ leftEncoder.isStruct,
+ rightEncoder.isStruct)
+ new Dataset(sparkSession, joinWith, joinEncoder)
}
// TODO(SPARK-22947): Fix the DataFrame API.
@@ -772,7 +781,7 @@ class Dataset[T] private[sql](
private def addDataFrameIdToCol(expr: NamedExpression): NamedExpression = {
val newExpr = expr transform {
case a: AttributeReference
- if sparkSession.conf.get(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED) =>
+ if sparkSession.sessionState.conf.getConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED) =>
val metadata = new MetadataBuilder()
.withMetadata(a.metadata)
.putLong(Dataset.DATASET_ID_KEY, id)
@@ -826,24 +835,29 @@ class Dataset[T] private[sql](
/** @inheritdoc */
def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = {
- implicit val encoder: ExpressionEncoder[U1] = encoderFor(c1.encoder)
+ val encoder = agnosticEncoderFor(c1.encoder)
val tc1 = withInputType(c1.named, exprEnc, logicalPlan.output)
val project = Project(tc1 :: Nil, logicalPlan)
- if (!encoder.isSerializedAsStructForTopLevel) {
- new Dataset[U1](sparkSession, project, encoder)
- } else {
- // Flattens inner fields of U1
- new Dataset[Tuple1[U1]](sparkSession, project, ExpressionEncoder.tuple(encoder)).map(_._1)
+ val plan = encoder match {
+ case se: StructEncoder[U1] =>
+ // Flatten the result.
+ val attribute = GetColumnByOrdinal(0, se.dataType)
+ val projectList = se.fields.zipWithIndex.map {
+ case (field, index) =>
+ Alias(GetStructField(attribute, index, None), field.name)()
+ }
+ Project(projectList, project)
+ case _ => project
}
+ new Dataset[U1](sparkSession, plan, encoder)
}
/** @inheritdoc */
protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
- val encoders = columns.map(c => encoderFor(c.encoder))
+ val encoders = columns.map(c => agnosticEncoderFor(c.encoder))
val namedColumns = columns.map(c => withInputType(c.named, exprEnc, logicalPlan.output))
- val execution = new QueryExecution(sparkSession, Project(namedColumns, logicalPlan))
- new Dataset(execution, ExpressionEncoder.tuple(encoders))
+ new Dataset(sparkSession, Project(namedColumns, logicalPlan), ProductEncoder.tuple(encoders))
}
/** @inheritdoc */
@@ -851,24 +865,7 @@ class Dataset[T] private[sql](
Filter(condition.expr, logicalPlan)
}
- /**
- * Groups the Dataset using the specified columns, so we can run aggregation on them. See
- * [[RelationalGroupedDataset]] for all the available aggregate functions.
- *
- * {{{
- * // Compute the average for all numeric columns grouped by department.
- * ds.groupBy($"department").avg()
- *
- * // Compute the max age and average salary, grouped by department and gender.
- * ds.groupBy($"department", $"gender").agg(Map(
- * "salary" -> "avg",
- * "age" -> "max"
- * ))
- * }}}
- *
- * @group untypedrel
- * @since 2.0.0
- */
+ /** @inheritdoc */
@scala.annotation.varargs
def groupBy(cols: Column*): RelationalGroupedDataset = {
RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.GroupByType)
@@ -900,35 +897,19 @@ class Dataset[T] private[sql](
rdd.reduce(func)
}
- /**
- * (Scala-specific)
- * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key `func`.
- *
- * @group typedrel
- * @since 2.0.0
- */
+ /** @inheritdoc */
def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = {
val withGroupingKey = AppendColumns(func, logicalPlan)
val executed = sparkSession.sessionState.executePlan(withGroupingKey)
new KeyValueGroupedDataset(
- encoderFor[K],
- encoderFor[T],
+ implicitly[Encoder[K]],
+ encoder,
executed,
logicalPlan.output,
withGroupingKey.newColumns)
}
- /**
- * (Java-specific)
- * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key `func`.
- *
- * @group typedrel
- * @since 2.0.0
- */
- def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] =
- groupByKey(func.call(_))(encoder)
-
/** @inheritdoc */
def unpivot(
ids: Array[Column],
@@ -981,6 +962,22 @@ class Dataset[T] private[sql](
valueColumnName: String): DataFrame =
unpivot(ids.toArray, variableColumnName, valueColumnName)
+ /** @inheritdoc */
+ def transpose(indexColumn: Column): DataFrame = withPlan {
+ UnresolvedTranspose(
+ Seq(indexColumn.named),
+ logicalPlan
+ )
+ }
+
+ /** @inheritdoc */
+ def transpose(): DataFrame = withPlan {
+ UnresolvedTranspose(
+ Seq.empty,
+ logicalPlan
+ )
+ }
+
/** @inheritdoc */
@scala.annotation.varargs
def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = withTypedPlan {
@@ -1362,12 +1359,6 @@ class Dataset[T] private[sql](
implicitly[Encoder[U]])
}
- /** @inheritdoc */
- def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
- val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).asScala
- mapPartitions(func)(encoder)
- }
-
/**
* Returns a new `DataFrame` that contains the result of applying a serialized R function
* `func` to each partition.
@@ -1377,7 +1368,11 @@ class Dataset[T] private[sql](
packageNames: Array[Byte],
broadcastVars: Array[Broadcast[Object]],
schema: StructType): DataFrame = {
- val rowEncoder = encoder.asInstanceOf[ExpressionEncoder[Row]]
+ val rowEncoder: ExpressionEncoder[Row] = if (isUnTyped) {
+ exprEnc.asInstanceOf[ExpressionEncoder[Row]]
+ } else {
+ ExpressionEncoder(schema)
+ }
Dataset.ofRows(
sparkSession,
MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, logicalPlan))
@@ -1426,11 +1421,6 @@ class Dataset[T] private[sql](
Option(profile)))
}
- /** @inheritdoc */
- def foreach(f: T => Unit): Unit = withNewRDDExecutionId("foreach") {
- rdd.foreach(f)
- }
-
/** @inheritdoc */
def foreachPartition(f: Iterator[T] => Unit): Unit = withNewRDDExecutionId("foreachPartition") {
rdd.foreachPartition(f)
@@ -1606,25 +1596,7 @@ class Dataset[T] private[sql](
new DataFrameWriterImpl[T](this)
}
- /**
- * Create a write configuration builder for v2 sources.
- *
- * This builder is used to configure and execute write operations. For example, to append to an
- * existing table, run:
- *
- * {{{
- * df.writeTo("catalog.db.table").append()
- * }}}
- *
- * This can also be used to create or replace existing tables:
- *
- * {{{
- * df.writeTo("catalog.db.table").partitionedBy($"col").createOrReplace()
- * }}}
- *
- * @group basic
- * @since 3.0.0
- */
+ /** @inheritdoc */
def writeTo(table: String): DataFrameWriterV2[T] = {
// TODO: streaming could be adapted to use this interface
if (isStreaming) {
@@ -1632,31 +1604,10 @@ class Dataset[T] private[sql](
errorClass = "CALL_ON_STREAMING_DATASET_UNSUPPORTED",
messageParameters = Map("methodName" -> toSQLId("writeTo")))
}
- new DataFrameWriterV2[T](table, this)
+ new DataFrameWriterV2Impl[T](table, this)
}
- /**
- * Merges a set of updates, insertions, and deletions based on a source table into
- * a target table.
- *
- * Scala Examples:
- * {{{
- * spark.table("source")
- * .mergeInto("target", $"source.id" === $"target.id")
- * .whenMatched($"salary" === 100)
- * .delete()
- * .whenNotMatched()
- * .insertAll()
- * .whenNotMatchedBySource($"salary" === 100)
- * .update(Map(
- * "salary" -> lit(200)
- * ))
- * .merge()
- * }}}
- *
- * @group basic
- * @since 4.0.0
- */
+ /** @inheritdoc */
def mergeInto(table: String, condition: Column): MergeIntoWriter[T] = {
if (isStreaming) {
logicalPlan.failAnalysis(
@@ -1664,7 +1615,7 @@ class Dataset[T] private[sql](
messageParameters = Map("methodName" -> toSQLId("mergeInto")))
}
- new MergeIntoWriter[T](table, this, condition)
+ new MergeIntoWriterImpl[T](table, this, condition)
}
/**
@@ -1684,7 +1635,7 @@ class Dataset[T] private[sql](
/** @inheritdoc */
override def toJSON: Dataset[String] = {
- val rowSchema = this.schema
+ val rowSchema = exprEnc.schema
val sessionLocalTimeZone = sparkSession.sessionState.conf.sessionLocalTimeZone
mapPartitions { iter =>
val writer = new CharArrayWriter()
@@ -1952,6 +1903,10 @@ class Dataset[T] private[sql](
override def dropDuplicatesWithinWatermark(col1: String, cols: String*): Dataset[T] =
super.dropDuplicatesWithinWatermark(col1, cols: _*)
+ /** @inheritdoc */
+ override def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] =
+ super.mapPartitions(f, encoder)
+
/** @inheritdoc */
override def flatMap[U: Encoder](func: T => IterableOnce[U]): Dataset[U] = super.flatMap(func)
@@ -1959,9 +1914,6 @@ class Dataset[T] private[sql](
override def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] =
super.flatMap(f, encoder)
- /** @inheritdoc */
- override def foreach(func: ForeachFunction[T]): Unit = super.foreach(func)
-
/** @inheritdoc */
override def foreachPartition(func: ForeachPartitionFunction[T]): Unit =
super.foreachPartition(func)
@@ -2018,18 +1970,16 @@ class Dataset[T] private[sql](
@scala.annotation.varargs
override def agg(expr: Column, exprs: Column*): DataFrame = super.agg(expr, exprs: _*)
+ /** @inheritdoc */
+ override def groupByKey[K](
+ func: MapFunction[T, K],
+ encoder: Encoder[K]): KeyValueGroupedDataset[K, T] =
+ super.groupByKey(func, encoder).asInstanceOf[KeyValueGroupedDataset[K, T]]
+
////////////////////////////////////////////////////////////////////////////
// For Python API
////////////////////////////////////////////////////////////////////////////
- /**
- * It adds a new long column with the name `name` that increases one by one.
- * This is for 'distributed-sequence' default index in pandas API on Spark.
- */
- private[sql] def withSequenceColumn(name: String) = {
- select(column(DistributedSequenceID()).alias(name), col("*"))
- }
-
/**
* Converts a JavaRDD to a PythonRDD.
*/
@@ -2257,7 +2207,7 @@ class Dataset[T] private[sql](
/** A convenient function to wrap a set based logical plan and produce a Dataset. */
@inline private def withSetOperator[U : Encoder](logicalPlan: LogicalPlan): Dataset[U] = {
- if (classTag.runtimeClass.isAssignableFrom(classOf[Row])) {
+ if (isUnTyped) {
// Set operators widen types (change the schema), so we cannot reuse the row encoder.
Dataset.ofRows(sparkSession, logicalPlan).asInstanceOf[Dataset[U]]
} else {
@@ -2265,6 +2215,8 @@ class Dataset[T] private[sql](
}
}
+ private def isUnTyped: Boolean = classTag.runtimeClass.isAssignableFrom(classOf[Row])
+
/** Returns a optimized plan for CommandResult, convert to `LocalRelation`. */
private def commandResultOptimized: Dataset[T] = {
logicalPlan match {
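To illustrate that the Dataset changes above are API-neutral: `writeTo` and `mergeInto` now construct `DataFrameWriterV2Impl` and `MergeIntoWriterImpl`, but callers keep the builder shown in the scaladoc removed above. A sketch adapted from that scaladoc, assuming `source` and `target` tables exist and `spark.implicits._` is imported:

    // Sketch adapted from the removed mergeInto scaladoc; table names are placeholders.
    spark.table("source")
      .mergeInto("target", $"source.id" === $"target.id")
      .whenMatched($"salary" === 100)
      .delete()
      .whenNotMatched()
      .insertAll()
      .merge()

    // The newly added transpose methods are ordinary Dataset operations as well, e.g.
    // df.transpose($"id") pivots rows into columns keyed by the `id` column
    // (hypothetical DataFrame `df`).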
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala
deleted file mode 100644
index 27012c471462d..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala
+++ /dev/null
@@ -1,136 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql
-
-/**
- * The abstract class for writing custom logic to process data generated by a query.
- * This is often used to write the output of a streaming query to arbitrary storage systems.
- * Any implementation of this base class will be used by Spark in the following way.
- *
- *
- *   - A single instance of this class is responsible of all the data generated by a single task
- * in a query. In other words, one instance is responsible for processing one partition of the
- * data generated in a distributed manner.
- *
- *   - Any implementation of this class must be serializable because each task will get a fresh
- * serialized-deserialized copy of the provided object. Hence, it is strongly recommended that
- * any initialization for writing data (e.g. opening a connection or starting a transaction)
- * is done after the `open(...)` method has been called, which signifies that the task is
- * ready to generate data.
- *
- *   - The lifecycle of the methods are as follows.
- *
- * For each partition with `partitionId`:
- * For each batch/epoch of streaming data (if its streaming query) with `epochId`:
- * Method `open(partitionId, epochId)` is called.
- * If `open` returns true:
- * For each row in the partition and batch/epoch, method `process(row)` is called.
- * Method `close(errorOrNull)` is called with error (if any) seen while processing rows.
- *
- * Important points to note:
- *
- *   - Spark doesn't guarantee same output for (partitionId, epochId), so deduplication
- * cannot be achieved with (partitionId, epochId). e.g. source provides different number of
- * partitions for some reason, Spark optimization changes number of partitions, etc.
- * Refer SPARK-28650 for more details. If you need deduplication on output, try out
- * `foreachBatch` instead.
- *
- *   - The `close()` method will be called if `open()` method returns successfully (irrespective
- * of the return value), except if the JVM crashes in the middle.
- *
- *
- * Scala example:
- * {{{
- * datasetOfString.writeStream.foreach(new ForeachWriter[String] {
- *
- * def open(partitionId: Long, version: Long): Boolean = {
- * // open connection
- * }
- *
- * def process(record: String) = {
- * // write string to connection
- * }
- *
- * def close(errorOrNull: Throwable): Unit = {
- * // close the connection
- * }
- * })
- * }}}
- *
- * Java example:
- * {{{
- * datasetOfString.writeStream().foreach(new ForeachWriter() {
- *
- * @Override
- * public boolean open(long partitionId, long version) {
- * // open connection
- * }
- *
- * @Override
- * public void process(String value) {
- * // write string to connection
- * }
- *
- * @Override
- * public void close(Throwable errorOrNull) {
- * // close the connection
- * }
- * });
- * }}}
- *
- * @since 2.0.0
- */
-abstract class ForeachWriter[T] extends Serializable {
-
- // TODO: Move this to org.apache.spark.sql.util or consolidate this with batch API.
-
- /**
- * Called when starting to process one partition of new data in the executor. See the class
- * docs for more information on how to use the `partitionId` and `epochId`.
- *
- * @param partitionId the partition id.
- * @param epochId a unique id for data deduplication.
- * @return `true` if the corresponding partition and version id should be processed. `false`
- * indicates the partition should be skipped.
- */
- def open(partitionId: Long, epochId: Long): Boolean
-
- /**
- * Called to process the data in the executor side. This method will be called only if `open`
- * returns `true`.
- */
- def process(value: T): Unit
-
- /**
- * Called when stopping to process one partition of new data in the executor side. This is
- * guaranteed to be called either `open` returns `true` or `false`. However,
- * `close` won't be called in the following cases:
- *
- *   - JVM crashes without throwing a `Throwable`
- *   - `open` throws a `Throwable`.
- *
- * @param errorOrNull the error thrown during processing data or null if there was no error.
- */
- def close(errorOrNull: Throwable): Unit
-}
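ForeachWriter.scala is deleted from sql/core here; given the rest of this refactor it is presumably relocated to the shared API module rather than dropped, so user implementations keep the same contract. A minimal sketch following the lifecycle described in the removed scaladoc (open once per partition/epoch, process per row, then close); the class name is hypothetical:

    // Sketch only: a ForeachWriter that "writes" records by printing them.
    // Real writers would open a connection in open(), write in process(),
    // and release resources in close().
    import org.apache.spark.sql.ForeachWriter

    class ConsoleWriter extends ForeachWriter[String] {
      def open(partitionId: Long, epochId: Long): Boolean = true   // ready to process this partition
      def process(record: String): Unit = println(record)          // called once per row
      def close(errorOrNull: Throwable): Unit = ()                  // nothing to release here
    }

    // Used from a streaming query (sketch):
    // dsOfString.writeStream.foreach(new ConsoleWriter).start()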
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index e3ea33a7504bf..c645ba57e8f82 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -17,13 +17,13 @@
package org.apache.spark.sql
-import scala.jdk.CollectionConverters._
-
import org.apache.spark.api.java.function._
import org.apache.spark.sql.catalyst.analysis.{EliminateEventTimeWatermark, UnresolvedAttribute}
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
-import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, ProductEncoder}
+import org.apache.spark.sql.catalyst.encoders.encoderFor
+import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.classic.ClassicConversions._
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.expressions.ReduceAggregator
import org.apache.spark.sql.internal.TypedAggUtils.{aggKeyColumn, withInputType}
@@ -41,79 +41,41 @@ class KeyValueGroupedDataset[K, V] private[sql](
vEncoder: Encoder[V],
@transient val queryExecution: QueryExecution,
private val dataAttributes: Seq[Attribute],
- private val groupingAttributes: Seq[Attribute]) extends Serializable {
+ private val groupingAttributes: Seq[Attribute])
+ extends api.KeyValueGroupedDataset[K, V] {
+ type KVDS[KY, VL] = KeyValueGroupedDataset[KY, VL]
- // Similar to [[Dataset]], we turn the passed in encoder to `ExpressionEncoder` explicitly.
- private implicit val kExprEnc: ExpressionEncoder[K] = encoderFor(kEncoder)
- private implicit val vExprEnc: ExpressionEncoder[V] = encoderFor(vEncoder)
+ private implicit def kEncoderImpl: Encoder[K] = kEncoder
+ private implicit def vEncoderImpl: Encoder[V] = vEncoder
private def logicalPlan = queryExecution.analyzed
private def sparkSession = queryExecution.sparkSession
import queryExecution.sparkSession._
- /**
- * Returns a new [[KeyValueGroupedDataset]] where the type of the key has been mapped to the
- * specified type. The mapping of key columns to the type follows the same rules as `as` on
- * [[Dataset]].
- *
- * @since 1.6.0
- */
+ /** @inheritdoc */
def keyAs[L : Encoder]: KeyValueGroupedDataset[L, V] =
new KeyValueGroupedDataset(
- encoderFor[L],
- vExprEnc,
+ implicitly[Encoder[L]],
+ vEncoder,
queryExecution,
dataAttributes,
groupingAttributes)
- /**
- * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied
- * to the data. The grouping key is unchanged by this.
- *
- * {{{
- * // Create values grouped by key from a Dataset[(K, V)]
- * ds.groupByKey(_._1).mapValues(_._2) // Scala
- * }}}
- *
- * @since 2.1.0
- */
+ /** @inheritdoc */
def mapValues[W : Encoder](func: V => W): KeyValueGroupedDataset[K, W] = {
val withNewData = AppendColumns(func, dataAttributes, logicalPlan)
val projected = Project(withNewData.newColumns ++ groupingAttributes, withNewData)
val executed = sparkSession.sessionState.executePlan(projected)
new KeyValueGroupedDataset(
- encoderFor[K],
- encoderFor[W],
+ kEncoder,
+ implicitly[Encoder[W]],
executed,
withNewData.newColumns,
groupingAttributes)
}
- /**
- * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied
- * to the data. The grouping key is unchanged by this.
- *
- * {{{
- * // Create Integer values grouped by String key from a Dataset<Tuple2<String, Integer>>
- * Dataset<Tuple2<String, Integer>> ds = ...;
- * KeyValueGroupedDataset<String, Integer> grouped =
- * ds.groupByKey(t -> t._1, Encoders.STRING()).mapValues(t -> t._2, Encoders.INT());
- * }}}
- *
- * @since 2.1.0
- */
- def mapValues[W](func: MapFunction[V, W], encoder: Encoder[W]): KeyValueGroupedDataset[K, W] = {
- implicit val uEnc = encoder
- mapValues { (v: V) => func.call(v) }
- }
-
- /**
- * Returns a [[Dataset]] that contains each unique key. This is equivalent to doing mapping
- * over the Dataset to extract the keys and then running a distinct operation on those.
- *
- * @since 1.6.0
- */
+ /** @inheritdoc */
def keys: Dataset[K] = {
Dataset[K](
sparkSession,
@@ -121,194 +83,23 @@ class KeyValueGroupedDataset[K, V] private[sql](
Project(groupingAttributes, logicalPlan)))
}
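As a quick illustration of the `keyAs`/`mapValues`/`keys` surface above, a hedged sketch follows; it assumes a local `SparkSession` and uses made-up data.

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("kvgd-sketch").getOrCreate()
import spark.implicits._

val sales = Seq(("apples", 3), ("apples", 5), ("pears", 2)).toDS()

val grouped    = sales.groupByKey(_._1)   // KeyValueGroupedDataset[String, (String, Int)]
val quantities = grouped.mapValues(_._2)  // value type becomes Int; the key is unchanged
quantities.keys.show()                    // distinct keys: "apples", "pears"
```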
- /**
- * (Scala-specific)
- * Applies the given function to each group of data. For each unique group, the function will
- * be passed the group key and an iterator that contains all of the elements in the group. The
- * function can return an iterator containing elements of an arbitrary type which will be returned
- * as a new [[Dataset]].
- *
- * This function does not support partial aggregation, and as a result requires shuffling all
- * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
- * key, it is best to use the reduce function or an
- * `org.apache.spark.sql.expressions#Aggregator`.
- *
- * Internally, the implementation will spill to disk if any given group is too large to fit into
- * memory. However, users must take care to avoid materializing the whole iterator for a group
- * (for example, by calling `toList`) unless they are sure that this is possible given the memory
- * constraints of their cluster.
- *
- * @since 1.6.0
- */
- def flatMapGroups[U : Encoder](f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] = {
- Dataset[U](
- sparkSession,
- MapGroups(
- f,
- groupingAttributes,
- dataAttributes,
- Seq.empty,
- logicalPlan))
- }
-
- /**
- * (Java-specific)
- * Applies the given function to each group of data. For each unique group, the function will
- * be passed the group key and an iterator that contains all of the elements in the group. The
- * function can return an iterator containing elements of an arbitrary type which will be returned
- * as a new [[Dataset]].
- *
- * This function does not support partial aggregation, and as a result requires shuffling all
- * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
- * key, it is best to use the reduce function or an
- * `org.apache.spark.sql.expressions#Aggregator`.
- *
- * Internally, the implementation will spill to disk if any given group is too large to fit into
- * memory. However, users must take care to avoid materializing the whole iterator for a group
- * (for example, by calling `toList`) unless they are sure that this is possible given the memory
- * constraints of their cluster.
- *
- * @since 1.6.0
- */
- def flatMapGroups[U](f: FlatMapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = {
- flatMapGroups((key, data) => f.call(key, data.asJava).asScala)(encoder)
- }
-
- /**
- * (Scala-specific)
- * Applies the given function to each group of data. For each unique group, the function will
- * be passed the group key and a sorted iterator that contains all of the elements in the group.
- * The function can return an iterator containing elements of an arbitrary type which will be
- * returned as a new [[Dataset]].
- *
- * This function does not support partial aggregation, and as a result requires shuffling all
- * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
- * key, it is best to use the reduce function or an
- * `org.apache.spark.sql.expressions#Aggregator`.
- *
- * Internally, the implementation will spill to disk if any given group is too large to fit into
- * memory. However, users must take care to avoid materializing the whole iterator for a group
- * (for example, by calling `toList`) unless they are sure that this is possible given the memory
- * constraints of their cluster.
- *
- * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except that the iterator
- * is sorted according to the given sort expressions. That sorting does not add
- * computational complexity.
- *
- * @see [[org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroups]]
- * @since 3.4.0
- */
+ /** @inheritdoc */
def flatMapSortedGroups[U : Encoder](
sortExprs: Column*)(
f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] = {
- val sortOrder: Seq[SortOrder] = MapGroups.sortOrder(sortExprs.map(_.expr))
-
Dataset[U](
sparkSession,
MapGroups(
f,
groupingAttributes,
dataAttributes,
- sortOrder,
+ MapGroups.sortOrder(sortExprs.map(_.expr)),
logicalPlan
)
)
}
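Continuing the sketch above (`grouped` and `spark.implicits._` are assumed to still be in scope), here is a hedged example of `flatMapSortedGroups`: each group's iterator arrives sorted by the `_2` column, so a per-key running total can be emitted without sorting inside the function.

```scala
import org.apache.spark.sql.functions.col

val runningTotals = grouped.flatMapSortedGroups(col("_2")) { (fruit, rows) =>
  // rows is sorted by quantity; emit (fruit, cumulative quantity) per element
  rows.map(_._2).scanLeft(0)(_ + _).drop(1).map(total => (fruit, total))
}
runningTotals.show()
```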
- /**
- * (Java-specific)
- * Applies the given function to each group of data. For each unique group, the function will
- * be passed the group key and a sorted iterator that contains all of the elements in the group.
- * The function can return an iterator containing elements of an arbitrary type which will be
- * returned as a new [[Dataset]].
- *
- * This function does not support partial aggregation, and as a result requires shuffling all
- * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
- * key, it is best to use the reduce function or an
- * `org.apache.spark.sql.expressions#Aggregator`.
- *
- * Internally, the implementation will spill to disk if any given group is too large to fit into
- * memory. However, users must take care to avoid materializing the whole iterator for a group
- * (for example, by calling `toList`) unless they are sure that this is possible given the memory
- * constraints of their cluster.
- *
- * This is equivalent to [[KeyValueGroupedDataset#flatMapGroups]], except that the iterator
- * is sorted according to the given sort expressions. That sorting does not add
- * computational complexity.
- *
- * @see [[org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroups]]
- * @since 3.4.0
- */
- def flatMapSortedGroups[U](
- SortExprs: Array[Column],
- f: FlatMapGroupsFunction[K, V, U],
- encoder: Encoder[U]): Dataset[U] = {
- import org.apache.spark.util.ArrayImplicits._
- flatMapSortedGroups(
- SortExprs.toImmutableArraySeq: _*)((key, data) => f.call(key, data.asJava).asScala)(encoder)
- }
-
- /**
- * (Scala-specific)
- * Applies the given function to each group of data. For each unique group, the function will
- * be passed the group key and an iterator that contains all of the elements in the group. The
- * function can return an element of arbitrary type which will be returned as a new [[Dataset]].
- *
- * This function does not support partial aggregation, and as a result requires shuffling all
- * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
- * key, it is best to use the reduce function or an
- * `org.apache.spark.sql.expressions#Aggregator`.
- *
- * Internally, the implementation will spill to disk if any given group is too large to fit into
- * memory. However, users must take care to avoid materializing the whole iterator for a group
- * (for example, by calling `toList`) unless they are sure that this is possible given the memory
- * constraints of their cluster.
- *
- * @since 1.6.0
- */
- def mapGroups[U : Encoder](f: (K, Iterator[V]) => U): Dataset[U] = {
- val func = (key: K, it: Iterator[V]) => Iterator(f(key, it))
- flatMapGroups(func)
- }
-
- /**
- * (Java-specific)
- * Applies the given function to each group of data. For each unique group, the function will
- * be passed the group key and an iterator that contains all of the elements in the group. The
- * function can return an element of arbitrary type which will be returned as a new [[Dataset]].
- *
- * This function does not support partial aggregation, and as a result requires shuffling all
- * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
- * key, it is best to use the reduce function or an
- * `org.apache.spark.sql.expressions#Aggregator`.
- *
- * Internally, the implementation will spill to disk if any given group is too large to fit into
- * memory. However, users must take care to avoid materializing the whole iterator for a group
- * (for example, by calling `toList`) unless they are sure that this is possible given the memory
- * constraints of their cluster.
- *
- * @since 1.6.0
- */
- def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = {
- mapGroups((key, data) => f.call(key, data.asJava))(encoder)
- }
-
- /**
- * (Scala-specific)
- * Applies the given function to each group of data, while maintaining a user-defined per-group
- * state. The result Dataset will represent the objects returned by the function.
- * For a static batch Dataset, the function will be invoked once per group. For a streaming
- * Dataset, the function will be invoked for each group repeatedly in every trigger, and
- * updates to each group's state will be saved across invocations.
- * See [[org.apache.spark.sql.streaming.GroupState]] for more details.
- *
- * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
- * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
- * @param func Function to be called on every group.
- *
- * See [[Encoder]] for more details on what types are encodable to Spark SQL.
- * @since 2.2.0
- */
+ /** @inheritdoc */
def mapGroupsWithState[S: Encoder, U: Encoder](
func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = {
val flatMapFunc = (key: K, it: Iterator[V], s: GroupState[S]) => Iterator(func(key, it, s))
@@ -324,23 +115,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
child = logicalPlan))
}
- /**
- * (Scala-specific)
- * Applies the given function to each group of data, while maintaining a user-defined per-group
- * state. The result Dataset will represent the objects returned by the function.
- * For a static batch Dataset, the function will be invoked once per group. For a streaming
- * Dataset, the function will be invoked for each group repeatedly in every trigger, and
- * updates to each group's state will be saved across invocations.
- * See [[org.apache.spark.sql.streaming.GroupState]] for more details.
- *
- * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
- * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
- * @param func Function to be called on every group.
- * @param timeoutConf Timeout configuration for groups that do not receive data for a while.
- *
- * See [[Encoder]] for more details on what types are encodable to Spark SQL.
- * @since 2.2.0
- */
+ /** @inheritdoc */
def mapGroupsWithState[S: Encoder, U: Encoder](
timeoutConf: GroupStateTimeout)(
func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = {
@@ -357,29 +132,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
child = logicalPlan))
}
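A hedged sketch of the `mapGroupsWithState` overloads above: `events` is an assumed `Dataset[(String, Long)]` of (key, timestamp) pairs with `spark.implicits._` in scope, and the state is simply a running count per key that a streaming query carries across triggers.

```scala
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}

case class KeyCount(key: String, count: Long)

val counts = events
  .groupByKey(_._1)
  .mapGroupsWithState(GroupStateTimeout.NoTimeout) {
    (key: String, rows: Iterator[(String, Long)], state: GroupState[Long]) =>
      // add this trigger's rows to whatever was accumulated in earlier triggers
      val newCount = state.getOption.getOrElse(0L) + rows.size
      state.update(newCount)
      KeyCount(key, newCount)
  }
```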
- /**
- * (Scala-specific)
- * Applies the given function to each group of data, while maintaining a user-defined per-group
- * state. The result Dataset will represent the objects returned by the function.
- * For a static batch Dataset, the function will be invoked once per group. For a streaming
- * Dataset, the function will be invoked for each group repeatedly in every trigger, and
- * updates to each group's state will be saved across invocations.
- * See [[org.apache.spark.sql.streaming.GroupState]] for more details.
- *
- * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
- * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
- * @param func Function to be called on every group.
- * @param timeoutConf Timeout Conf, see GroupStateTimeout for more details
- * @param initialState The user provided state that will be initialized when the first batch
- * of data is processed in the streaming query. The user defined function
- * will be called on the state data even if there are no other values in
- * the group. To convert a Dataset ds of type Dataset[(K, S)] to a
- * KeyValueGroupedDataset[K, S]
- * do {{{ ds.groupByKey(x => x._1).mapValues(_._2) }}}
- *
- * See [[Encoder]] for more details on what types are encodable to Spark SQL.
- * @since 3.2.0
- */
+ /** @inheritdoc */
def mapGroupsWithState[S: Encoder, U: Encoder](
timeoutConf: GroupStateTimeout,
initialState: KeyValueGroupedDataset[K, S])(
@@ -402,114 +155,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
))
}
- /**
- * (Java-specific)
- * Applies the given function to each group of data, while maintaining a user-defined per-group
- * state. The result Dataset will represent the objects returned by the function.
- * For a static batch Dataset, the function will be invoked once per group. For a streaming
- * Dataset, the function will be invoked for each group repeatedly in every trigger, and
- * updates to each group's state will be saved across invocations.
- * See `GroupState` for more details.
- *
- * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
- * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
- * @param func Function to be called on every group.
- * @param stateEncoder Encoder for the state type.
- * @param outputEncoder Encoder for the output type.
- *
- * See [[Encoder]] for more details on what types are encodable to Spark SQL.
- * @since 2.2.0
- */
- def mapGroupsWithState[S, U](
- func: MapGroupsWithStateFunction[K, V, S, U],
- stateEncoder: Encoder[S],
- outputEncoder: Encoder[U]): Dataset[U] = {
- mapGroupsWithState[S, U](
- (key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s)
- )(stateEncoder, outputEncoder)
- }
-
- /**
- * (Java-specific)
- * Applies the given function to each group of data, while maintaining a user-defined per-group
- * state. The result Dataset will represent the objects returned by the function.
- * For a static batch Dataset, the function will be invoked once per group. For a streaming
- * Dataset, the function will be invoked for each group repeatedly in every trigger, and
- * updates to each group's state will be saved across invocations.
- * See `GroupState` for more details.
- *
- * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
- * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
- * @param func Function to be called on every group.
- * @param stateEncoder Encoder for the state type.
- * @param outputEncoder Encoder for the output type.
- * @param timeoutConf Timeout configuration for groups that do not receive data for a while.
- *
- * See [[Encoder]] for more details on what types are encodable to Spark SQL.
- * @since 2.2.0
- */
- def mapGroupsWithState[S, U](
- func: MapGroupsWithStateFunction[K, V, S, U],
- stateEncoder: Encoder[S],
- outputEncoder: Encoder[U],
- timeoutConf: GroupStateTimeout): Dataset[U] = {
- mapGroupsWithState[S, U](timeoutConf)(
- (key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s)
- )(stateEncoder, outputEncoder)
- }
-
- /**
- * (Java-specific)
- * Applies the given function to each group of data, while maintaining a user-defined per-group
- * state. The result Dataset will represent the objects returned by the function.
- * For a static batch Dataset, the function will be invoked once per group. For a streaming
- * Dataset, the function will be invoked for each group repeatedly in every trigger, and
- * updates to each group's state will be saved across invocations.
- * See `GroupState` for more details.
- *
- * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
- * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
- * @param func Function to be called on every group.
- * @param stateEncoder Encoder for the state type.
- * @param outputEncoder Encoder for the output type.
- * @param timeoutConf Timeout configuration for groups that do not receive data for a while.
- * @param initialState The user provided state that will be initialized when the first batch
- * of data is processed in the streaming query. The user defined function
- * will be called on the state data even if there are no other values in
- * the group.
- *
- * See [[Encoder]] for more details on what types are encodable to Spark SQL.
- * @since 3.2.0
- */
- def mapGroupsWithState[S, U](
- func: MapGroupsWithStateFunction[K, V, S, U],
- stateEncoder: Encoder[S],
- outputEncoder: Encoder[U],
- timeoutConf: GroupStateTimeout,
- initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = {
- mapGroupsWithState[S, U](timeoutConf, initialState)(
- (key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s)
- )(stateEncoder, outputEncoder)
- }
-
- /**
- * (Scala-specific)
- * Applies the given function to each group of data, while maintaining a user-defined per-group
- * state. The result Dataset will represent the objects returned by the function.
- * For a static batch Dataset, the function will be invoked once per group. For a streaming
- * Dataset, the function will be invoked for each group repeatedly in every trigger, and
- * updates to each group's state will be saved across invocations.
- * See `GroupState` for more details.
- *
- * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
- * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
- * @param func Function to be called on every group.
- * @param outputMode The output mode of the function.
- * @param timeoutConf Timeout configuration for groups that do not receive data for a while.
- *
- * See [[Encoder]] for more details on what types are encodable to Spark SQL.
- * @since 2.2.0
- */
+ /** @inheritdoc */
def flatMapGroupsWithState[S: Encoder, U: Encoder](
outputMode: OutputMode,
timeoutConf: GroupStateTimeout)(
@@ -529,29 +175,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
child = logicalPlan))
}
- /**
- * (Scala-specific)
- * Applies the given function to each group of data, while maintaining a user-defined per-group
- * state. The result Dataset will represent the objects returned by the function.
- * For a static batch Dataset, the function will be invoked once per group. For a streaming
- * Dataset, the function will be invoked for each group repeatedly in every trigger, and
- * updates to each group's state will be saved across invocations.
- * See `GroupState` for more details.
- *
- * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
- * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
- * @param func Function to be called on every group.
- * @param outputMode The output mode of the function.
- * @param timeoutConf Timeout configuration for groups that do not receive data for a while.
- * @param initialState The user provided state that will be initialized when the first batch
- * of data is processed in the streaming query. The user defined function
- * will be called on the state data even if there are no other values in
- * the group. To convert a Dataset `ds` of type `Dataset[(K, S)]`
- * to a `KeyValueGroupedDataset[K, S]`, use
- * {{{ ds.groupByKey(x => x._1).mapValues(_._2) }}}
- * See [[Encoder]] for more details on what types are encodable to Spark SQL.
- * @since 3.2.0
- */
+ /** @inheritdoc */
def flatMapGroupsWithState[S: Encoder, U: Encoder](
outputMode: OutputMode,
timeoutConf: GroupStateTimeout,
@@ -576,91 +200,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
))
}
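And a hedged `flatMapGroupsWithState` sketch built on the same assumed `events` Dataset: unlike `mapGroupsWithState` it may emit zero rows for a group, which is used here to raise an alert only when a threshold is crossed (the threshold and timeout values are illustrative).

```scala
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}

val alerts = events
  .groupByKey(_._1)
  .flatMapGroupsWithState(OutputMode.Update(), GroupStateTimeout.ProcessingTimeTimeout()) {
    (key: String, rows: Iterator[(String, Long)], state: GroupState[Long]) =>
      val total = state.getOption.getOrElse(0L) + rows.size
      state.update(total)
      state.setTimeoutDuration("10 minutes") // expire idle keys in a streaming query
      if (total > 100) Iterator(s"$key exceeded 100 events") else Iterator.empty
  }
```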
- /**
- * (Java-specific)
- * Applies the given function to each group of data, while maintaining a user-defined per-group
- * state. The result Dataset will represent the objects returned by the function.
- * For a static batch Dataset, the function will be invoked once per group. For a streaming
- * Dataset, the function will be invoked for each group repeatedly in every trigger, and
- * updates to each group's state will be saved across invocations.
- * See `GroupState` for more details.
- *
- * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
- * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
- * @param func Function to be called on every group.
- * @param outputMode The output mode of the function.
- * @param stateEncoder Encoder for the state type.
- * @param outputEncoder Encoder for the output type.
- * @param timeoutConf Timeout configuration for groups that do not receive data for a while.
- *
- * See [[Encoder]] for more details on what types are encodable to Spark SQL.
- * @since 2.2.0
- */
- def flatMapGroupsWithState[S, U](
- func: FlatMapGroupsWithStateFunction[K, V, S, U],
- outputMode: OutputMode,
- stateEncoder: Encoder[S],
- outputEncoder: Encoder[U],
- timeoutConf: GroupStateTimeout): Dataset[U] = {
- val f = (key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s).asScala
- flatMapGroupsWithState[S, U](outputMode, timeoutConf)(f)(stateEncoder, outputEncoder)
- }
-
- /**
- * (Java-specific)
- * Applies the given function to each group of data, while maintaining a user-defined per-group
- * state. The result Dataset will represent the objects returned by the function.
- * For a static batch Dataset, the function will be invoked once per group. For a streaming
- * Dataset, the function will be invoked for each group repeatedly in every trigger, and
- * updates to each group's state will be saved across invocations.
- * See `GroupState` for more details.
- *
- * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
- * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
- * @param func Function to be called on every group.
- * @param outputMode The output mode of the function.
- * @param stateEncoder Encoder for the state type.
- * @param outputEncoder Encoder for the output type.
- * @param timeoutConf Timeout configuration for groups that do not receive data for a while.
- * @param initialState The user provided state that will be initialized when the first batch
- * of data is processed in the streaming query. The user defined function
- * will be called on the state data even if there are no other values in
- * the group. To convert a Dataset `ds` of type `Dataset[(K, S)]`
- * to a `KeyValueGroupedDataset[K, S]`, use
- * {{{ ds.groupByKey(x => x._1).mapValues(_._2) }}}
- *
- * See [[Encoder]] for more details on what types are encodable to Spark SQL.
- * @since 3.2.0
- */
- def flatMapGroupsWithState[S, U](
- func: FlatMapGroupsWithStateFunction[K, V, S, U],
- outputMode: OutputMode,
- stateEncoder: Encoder[S],
- outputEncoder: Encoder[U],
- timeoutConf: GroupStateTimeout,
- initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = {
- val f = (key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s).asScala
- flatMapGroupsWithState[S, U](
- outputMode, timeoutConf, initialState)(f)(stateEncoder, outputEncoder)
- }
-
- /**
- * (Scala-specific)
- * Invokes methods defined in the stateful processor used in arbitrary state API v2.
- * We allow the user to act on a per-group set of input rows along with keyed state and the
- * user can choose to output/return 0 or more rows.
- * For a streaming dataframe, we will repeatedly invoke the interface methods for new rows
- * in each trigger and the user's state/state variables will be stored persistently across
- * invocations.
- *
- * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
- * @param statefulProcessor Instance of statefulProcessor whose functions will be invoked
- * by the operator.
- * @param timeMode The time mode semantics of the stateful processor for timers and TTL.
- * @param outputMode The output mode of the stateful processor.
- *
- * See [[Encoder]] for more details on what types are encodable to Spark SQL.
- */
+ /** @inheritdoc */
private[sql] def transformWithState[U: Encoder](
statefulProcessor: StatefulProcessor[K, V, U],
timeMode: TimeMode,
@@ -678,29 +218,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
)
}
- /**
- * (Scala-specific)
- * Invokes methods defined in the stateful processor used in arbitrary state API v2.
- * We allow the user to act on a per-group set of input rows along with keyed state and the
- * user can choose to output/return 0 or more rows.
- * For a streaming dataframe, we will repeatedly invoke the interface methods for new rows
- * in each trigger and the user's state/state variables will be stored persistently across
- * invocations.
- *
- * Downstream operators would use specified eventTimeColumnName to calculate watermark.
- * Note that TimeMode is set to EventTime to ensure correct flow of watermark.
- *
- * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
- * @param statefulProcessor Instance of statefulProcessor whose functions will
- * be invoked by the operator.
- * @param eventTimeColumnName eventTime column in the output dataset. Any operations after
- * transformWithState will use the new eventTimeColumn. The user
- * needs to ensure that the eventTime for emitted output adheres to
- * the watermark boundary, otherwise streaming query will fail.
- * @param outputMode The output mode of the stateful processor.
- *
- * See [[Encoder]] for more details on what types are encodable to Spark SQL.
- */
+ /** @inheritdoc */
private[sql] def transformWithState[U: Encoder](
statefulProcessor: StatefulProcessor[K, V, U],
eventTimeColumnName: String,
@@ -716,81 +234,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
updateEventTimeColumnAfterTransformWithState(transformWithState, eventTimeColumnName)
}
- /**
- * (Java-specific)
- * Invokes methods defined in the stateful processor used in arbitrary state API v2.
- * We allow the user to act on a per-group set of input rows along with keyed state and the
- * user can choose to output/return 0 or more rows.
- * For a streaming dataframe, we will repeatedly invoke the interface methods for new rows
- * in each trigger and the user's state/state variables will be stored persistently across
- * invocations.
- *
- * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
- * @param statefulProcessor Instance of statefulProcessor whose functions will be invoked by the
- * operator.
- * @param timeMode The time mode semantics of the stateful processor for timers and TTL.
- * @param outputMode The output mode of the stateful processor.
- * @param outputEncoder Encoder for the output type.
- *
- * See [[Encoder]] for more details on what types are encodable to Spark SQL.
- */
- private[sql] def transformWithState[U: Encoder](
- statefulProcessor: StatefulProcessor[K, V, U],
- timeMode: TimeMode,
- outputMode: OutputMode,
- outputEncoder: Encoder[U]): Dataset[U] = {
- transformWithState(statefulProcessor, timeMode, outputMode)(outputEncoder)
- }
-
- /**
- * (Java-specific)
- * Invokes methods defined in the stateful processor used in arbitrary state API v2.
- * We allow the user to act on a per-group set of input rows along with keyed state and the
- * user can choose to output/return 0 or more rows.
- *
- * For a streaming dataframe, we will repeatedly invoke the interface methods for new rows
- * in each trigger and the user's state/state variables will be stored persistently across
- * invocations.
- *
- * Downstream operators would use specified eventTimeColumnName to calculate watermark.
- * Note that TimeMode is set to EventTime to ensure correct flow of watermark.
- *
- * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
- * @param statefulProcessor Instance of statefulProcessor whose functions will be invoked by the
- * operator.
- * @param eventTimeColumnName eventTime column in the output dataset. Any operations after
- * transformWithState will use the new eventTimeColumn. The user
- * needs to ensure that the eventTime for emitted output adheres to
- * the watermark boundary, otherwise streaming query will fail.
- * @param outputMode The output mode of the stateful processor.
- * @param outputEncoder Encoder for the output type.
- *
- * See [[Encoder]] for more details on what types are encodable to Spark SQL.
- */
- private[sql] def transformWithState[U: Encoder](
- statefulProcessor: StatefulProcessor[K, V, U],
- eventTimeColumnName: String,
- outputMode: OutputMode,
- outputEncoder: Encoder[U]): Dataset[U] = {
- transformWithState(statefulProcessor, eventTimeColumnName, outputMode)(outputEncoder)
- }
-
- /**
- * (Scala-specific)
- * Invokes methods defined in the stateful processor used in arbitrary state API v2.
- * Functions as the function above, but with additional initial state.
- *
- * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
- * @tparam S The type of initial state objects. Must be encodable to Spark SQL types.
- * @param statefulProcessor Instance of statefulProcessor whose functions will
- * be invoked by the operator.
- * @param timeMode The time mode semantics of the stateful processor for timers and TTL.
- * @param outputMode The output mode of the stateful processor.
- * @param initialState User provided initial state that will be used to initiate state for
- * the query in the first batch.
- *
- * See [[Encoder]] for more details on what types are encodable to Spark SQL.
- */
+ /** @inheritdoc */
private[sql] def transformWithState[U: Encoder, S: Encoder](
statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
timeMode: TimeMode,
@@ -812,29 +256,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
)
}
- /**
- * (Scala-specific)
- * Invokes methods defined in the stateful processor used in arbitrary state API v2.
- * Functions as the function above, but with additional eventTimeColumnName for output.
- *
- * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
- * @tparam S The type of initial state objects. Must be encodable to Spark SQL types.
- *
- * Downstream operators would use specified eventTimeColumnName to calculate watermark.
- * Note that TimeMode is set to EventTime to ensure correct flow of watermark.
- *
- * @param statefulProcessor Instance of statefulProcessor whose functions will
- * be invoked by the operator.
- * @param eventTimeColumnName eventTime column in the output dataset. Any operations after
- * transformWithState will use the new eventTimeColumn. The user
- * needs to ensure that the eventTime for emitted output adheres to
- * the watermark boundary, otherwise streaming query will fail.
- * @param outputMode The output mode of the stateful processor.
- * @param initialState User provided initial state that will be used to initiate state for
- * the query in the first batch.
- *
- * See [[Encoder]] for more details on what types are encodable to Spark SQL.
- */
+ /** @inheritdoc */
private[sql] def transformWithState[U: Encoder, S: Encoder](
statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
eventTimeColumnName: String,
@@ -855,71 +277,6 @@ class KeyValueGroupedDataset[K, V] private[sql](
updateEventTimeColumnAfterTransformWithState(transformWithState, eventTimeColumnName)
}
- /**
- * (Java-specific)
- * Invokes methods defined in the stateful processor used in arbitrary state API v2.
- * Functions as the function above, but with additional initialStateEncoder for state encoding.
- *
- * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
- * @tparam S The type of initial state objects. Must be encodable to Spark SQL types.
- * @param statefulProcessor Instance of statefulProcessor whose functions will
- * be invoked by the operator.
- * @param timeMode The time mode semantics of the stateful processor for
- * timers and TTL.
- * @param outputMode The output mode of the stateful processor.
- * @param initialState User provided initial state that will be used to initiate state for
- * the query in the first batch.
- * @param outputEncoder Encoder for the output type.
- * @param initialStateEncoder Encoder for the initial state type.
- *
- * See [[Encoder]] for more details on what types are encodable to Spark SQL.
- */
- private[sql] def transformWithState[U: Encoder, S: Encoder](
- statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
- timeMode: TimeMode,
- outputMode: OutputMode,
- initialState: KeyValueGroupedDataset[K, S],
- outputEncoder: Encoder[U],
- initialStateEncoder: Encoder[S]): Dataset[U] = {
- transformWithState(statefulProcessor, timeMode,
- outputMode, initialState)(outputEncoder, initialStateEncoder)
- }
-
- /**
- * (Java-specific)
- * Invokes methods defined in the stateful processor used in arbitrary state API v2.
- * Functions as the function above, but with additional eventTimeColumnName for output.
- *
- * Downstream operators would use specified eventTimeColumnName to calculate watermark.
- * Note that TimeMode is set to EventTime to ensure correct flow of watermark.
- *
- * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
- * @tparam S The type of initial state objects. Must be encodable to Spark SQL types.
- * @param statefulProcessor Instance of statefulProcessor whose functions will
- * be invoked by the operator.
- * @param outputMode The output mode of the stateful processor.
- * @param initialState User provided initial state that will be used to initiate state for
- * the query in the first batch.
- * @param eventTimeColumnName event column in the output dataset. Any operations after
- * transformWithState will use the new eventTimeColumn. The user
- * needs to ensure that the eventTime for emitted output adheres to
- * the watermark boundary, otherwise streaming query will fail.
- * @param outputEncoder Encoder for the output type.
- * @param initialStateEncoder Encoder for the initial state type.
- *
- * See [[Encoder]] for more details on what types are encodable to Spark SQL.
- */
- private[sql] def transformWithState[U: Encoder, S: Encoder](
- statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
- outputMode: OutputMode,
- initialState: KeyValueGroupedDataset[K, S],
- eventTimeColumnName: String,
- outputEncoder: Encoder[U],
- initialStateEncoder: Encoder[S]): Dataset[U] = {
- transformWithState(statefulProcessor, eventTimeColumnName,
- outputMode, initialState)(outputEncoder, initialStateEncoder)
- }
-
/**
* Creates a new dataset with updated eventTimeColumn after the transformWithState
* logical node.
@@ -939,125 +296,238 @@ class KeyValueGroupedDataset[K, V] private[sql](
transformWithStateDataset.logicalPlan)))
}
- /**
- * (Scala-specific)
- * Reduces the elements of each group of data using the specified binary function.
- * The given function must be commutative and associative or the result may be non-deterministic.
- *
- * @since 1.6.0
- */
+ /** @inheritdoc */
def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = {
- val vEncoder = encoderFor[V]
val aggregator: TypedColumn[V, V] = new ReduceAggregator[V](f)(vEncoder).toColumn
agg(aggregator)
}
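A hedged one-liner for `reduceGroups`, reusing the assumed `sales` Dataset from the earlier sketch: the reduction stays typed, and the result pairs each key with the winning element.

```scala
// Largest single sale per fruit; the result is a Dataset[(String, (String, Int))].
val maxPerFruit = sales
  .groupByKey(_._1)
  .reduceGroups((a, b) => if (a._2 >= b._2) a else b)
maxPerFruit.show()
```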
- /**
- * (Java-specific)
- * Reduces the elements of each group of data using the specified binary function.
- * The given function must be commutative and associative or the result may be non-deterministic.
- *
- * @since 1.6.0
- */
- def reduceGroups(f: ReduceFunction[V]): Dataset[(K, V)] = {
- reduceGroups(f.call _)
- }
-
- /**
- * Internal helper function for building typed aggregations that return tuples. For simplicity
- * and code reuse, we do this without the help of the type system and then use helper functions
- * that cast appropriately for the user facing interface.
- */
+ /** @inheritdoc */
protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
- val encoders = columns.map(c => encoderFor(c.encoder))
- val namedColumns = columns.map(c => withInputType(c.named, vExprEnc, dataAttributes))
- val keyColumn = aggKeyColumn(kExprEnc, groupingAttributes)
+ val keyAgEncoder = agnosticEncoderFor(kEncoder)
+ val valueExprEncoder = encoderFor(vEncoder)
+ val encoders = columns.map(c => agnosticEncoderFor(c.encoder))
+ val namedColumns = columns.map { c =>
+ withInputType(c.named, valueExprEncoder, dataAttributes)
+ }
+ val keyColumn = aggKeyColumn(keyAgEncoder, groupingAttributes)
val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan)
- val execution = new QueryExecution(sparkSession, aggregate)
+ new Dataset(sparkSession, aggregate, ProductEncoder.tuple(keyAgEncoder +: encoders))
+ }
+
+ /** @inheritdoc */
+ def cogroupSorted[U, R : Encoder](
+ other: KeyValueGroupedDataset[K, U])(
+ thisSortExprs: Column*)(
+ otherSortExprs: Column*)(
+ f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = {
+ implicit val uEncoder = other.vEncoderImpl
+ Dataset[R](
+ sparkSession,
+ CoGroup(
+ f,
+ this.groupingAttributes,
+ other.groupingAttributes,
+ this.dataAttributes,
+ other.dataAttributes,
+ MapGroups.sortOrder(thisSortExprs.map(_.expr)),
+ MapGroups.sortOrder(otherSortExprs.map(_.expr)),
+ this.logicalPlan,
+ other.logicalPlan))
+ }
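A hedged sketch of `cogroupSorted`, again reusing the assumed `sales` Dataset plus a second made-up Dataset of restocks: both per-key iterators arrive sorted by their `_2` column, and the function nets restocked quantity against sold quantity.

```scala
import org.apache.spark.sql.functions.col

val restocks = Seq(("apples", 10), ("pears", 4)).toDS()

val stockDelta = sales.groupByKey(_._1)
  .cogroupSorted(restocks.groupByKey(_._1))(col("_2"))(col("_2")) {
    (fruit, sold, received) =>
      Iterator(fruit -> (received.map(_._2).sum - sold.map(_._2).sum))
  }
stockDelta.show()   // e.g. (apples, 2), (pears, 2)
```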
- new Dataset(execution, ExpressionEncoder.tuple(kExprEnc +: encoders))
+ override def toString: String = {
+ val builder = new StringBuilder
+ val kFields = kEncoder.schema.map { f =>
+ s"${f.name}: ${f.dataType.simpleString(2)}"
+ }
+ val vFields = vEncoder.schema.map { f =>
+ s"${f.name}: ${f.dataType.simpleString(2)}"
+ }
+ builder.append("KeyValueGroupedDataset: [key: [")
+ builder.append(kFields.take(2).mkString(", "))
+ if (kFields.length > 2) {
+ builder.append(" ... " + (kFields.length - 2) + " more field(s)")
+ }
+ builder.append("], value: [")
+ builder.append(vFields.take(2).mkString(", "))
+ if (vFields.length > 2) {
+ builder.append(" ... " + (vFields.length - 2) + " more field(s)")
+ }
+ builder.append("]]").toString()
}
- /**
- * Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key
- * and the result of computing this aggregation over all elements in the group.
- *
- * @since 1.6.0
- */
- def agg[U1](col1: TypedColumn[V, U1]): Dataset[(K, U1)] =
- aggUntyped(col1).asInstanceOf[Dataset[(K, U1)]]
+ ////////////////////////////////////////////////////////////////////////////
+ // Return type overrides to make sure we return the implementation instead
+ // of the interface.
+ ////////////////////////////////////////////////////////////////////////////
+ /** @inheritdoc */
+ override def mapValues[W](
+ func: MapFunction[V, W],
+ encoder: Encoder[W]): KeyValueGroupedDataset[K, W] = super.mapValues(func, encoder)
- /**
- * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
- * and the result of computing these aggregations over all elements in the group.
- *
- * @since 1.6.0
- */
- def agg[U1, U2](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)] =
- aggUntyped(col1, col2).asInstanceOf[Dataset[(K, U1, U2)]]
+ /** @inheritdoc */
+ override def flatMapGroups[U: Encoder](f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] =
+ super.flatMapGroups(f)
- /**
- * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
- * and the result of computing these aggregations over all elements in the group.
- *
- * @since 1.6.0
- */
- def agg[U1, U2, U3](
+ /** @inheritdoc */
+ override def flatMapGroups[U](
+ f: FlatMapGroupsFunction[K, V, U],
+ encoder: Encoder[U]): Dataset[U] = super.flatMapGroups(f, encoder)
+
+ /** @inheritdoc */
+ override def flatMapSortedGroups[U](
+ SortExprs: Array[Column],
+ f: FlatMapGroupsFunction[K, V, U],
+ encoder: Encoder[U]): Dataset[U] = super.flatMapSortedGroups(SortExprs, f, encoder)
+
+ /** @inheritdoc */
+ override def mapGroups[U: Encoder](f: (K, Iterator[V]) => U): Dataset[U] = super.mapGroups(f)
+
+ /** @inheritdoc */
+ override def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] =
+ super.mapGroups(f, encoder)
+
+ /** @inheritdoc */
+ override def mapGroupsWithState[S, U](
+ func: MapGroupsWithStateFunction[K, V, S, U],
+ stateEncoder: Encoder[S],
+ outputEncoder: Encoder[U]): Dataset[U] =
+ super.mapGroupsWithState(func, stateEncoder, outputEncoder)
+
+ /** @inheritdoc */
+ override def mapGroupsWithState[S, U](
+ func: MapGroupsWithStateFunction[K, V, S, U],
+ stateEncoder: Encoder[S],
+ outputEncoder: Encoder[U],
+ timeoutConf: GroupStateTimeout): Dataset[U] =
+ super.mapGroupsWithState(func, stateEncoder, outputEncoder, timeoutConf)
+
+ /** @inheritdoc */
+ override def mapGroupsWithState[S, U](
+ func: MapGroupsWithStateFunction[K, V, S, U],
+ stateEncoder: Encoder[S],
+ outputEncoder: Encoder[U],
+ timeoutConf: GroupStateTimeout,
+ initialState: KeyValueGroupedDataset[K, S]): Dataset[U] =
+ super.mapGroupsWithState(func, stateEncoder, outputEncoder, timeoutConf, initialState)
+
+ /** @inheritdoc */
+ override def flatMapGroupsWithState[S, U](
+ func: FlatMapGroupsWithStateFunction[K, V, S, U],
+ outputMode: OutputMode,
+ stateEncoder: Encoder[S],
+ outputEncoder: Encoder[U],
+ timeoutConf: GroupStateTimeout): Dataset[U] =
+ super.flatMapGroupsWithState(func, outputMode, stateEncoder, outputEncoder, timeoutConf)
+
+ /** @inheritdoc */
+ override def flatMapGroupsWithState[S, U](
+ func: FlatMapGroupsWithStateFunction[K, V, S, U],
+ outputMode: OutputMode,
+ stateEncoder: Encoder[S],
+ outputEncoder: Encoder[U],
+ timeoutConf: GroupStateTimeout,
+ initialState: KeyValueGroupedDataset[K, S]): Dataset[U] =
+ super.flatMapGroupsWithState(
+ func,
+ outputMode,
+ stateEncoder,
+ outputEncoder,
+ timeoutConf,
+ initialState)
+
+ /** @inheritdoc */
+ override private[sql] def transformWithState[U: Encoder](
+ statefulProcessor: StatefulProcessor[K, V, U],
+ timeMode: TimeMode,
+ outputMode: OutputMode,
+ outputEncoder: Encoder[U]) =
+ super.transformWithState(statefulProcessor, timeMode, outputMode, outputEncoder)
+
+ /** @inheritdoc */
+ override private[sql] def transformWithState[U: Encoder](
+ statefulProcessor: StatefulProcessor[K, V, U],
+ eventTimeColumnName: String,
+ outputMode: OutputMode,
+ outputEncoder: Encoder[U]) =
+ super.transformWithState(statefulProcessor, eventTimeColumnName, outputMode, outputEncoder)
+
+ /** @inheritdoc */
+ override private[sql] def transformWithState[U: Encoder, S: Encoder](
+ statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
+ timeMode: TimeMode,
+ outputMode: OutputMode,
+ initialState: KeyValueGroupedDataset[K, S],
+ outputEncoder: Encoder[U],
+ initialStateEncoder: Encoder[S]) = super.transformWithState(
+ statefulProcessor,
+ timeMode,
+ outputMode,
+ initialState,
+ outputEncoder,
+ initialStateEncoder)
+
+ /** @inheritdoc */
+ override private[sql] def transformWithState[U: Encoder, S: Encoder](
+ statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
+ outputMode: OutputMode,
+ initialState: KeyValueGroupedDataset[K, S],
+ eventTimeColumnName: String,
+ outputEncoder: Encoder[U],
+ initialStateEncoder: Encoder[S]) = super.transformWithState(
+ statefulProcessor,
+ outputMode,
+ initialState,
+ eventTimeColumnName,
+ outputEncoder,
+ initialStateEncoder)
+
+ /** @inheritdoc */
+ override def reduceGroups(f: ReduceFunction[V]): Dataset[(K, V)] = super.reduceGroups(f)
+
+ /** @inheritdoc */
+ override def agg[U1](col1: TypedColumn[V, U1]): Dataset[(K, U1)] = super.agg(col1)
+
+ /** @inheritdoc */
+ override def agg[U1, U2](
+ col1: TypedColumn[V, U1],
+ col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)] = super.agg(col1, col2)
+
+ /** @inheritdoc */
+ override def agg[U1, U2, U3](
col1: TypedColumn[V, U1],
col2: TypedColumn[V, U2],
- col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)] =
- aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, U1, U2, U3)]]
+ col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)] = super.agg(col1, col2, col3)
- /**
- * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
- * and the result of computing these aggregations over all elements in the group.
- *
- * @since 1.6.0
- */
- def agg[U1, U2, U3, U4](
+ /** @inheritdoc */
+ override def agg[U1, U2, U3, U4](
col1: TypedColumn[V, U1],
col2: TypedColumn[V, U2],
col3: TypedColumn[V, U3],
- col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)] =
- aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, U1, U2, U3, U4)]]
+ col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)] = super.agg(col1, col2, col3, col4)
- /**
- * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
- * and the result of computing these aggregations over all elements in the group.
- *
- * @since 3.0.0
- */
- def agg[U1, U2, U3, U4, U5](
+ /** @inheritdoc */
+ override def agg[U1, U2, U3, U4, U5](
col1: TypedColumn[V, U1],
col2: TypedColumn[V, U2],
col3: TypedColumn[V, U3],
col4: TypedColumn[V, U4],
col5: TypedColumn[V, U5]): Dataset[(K, U1, U2, U3, U4, U5)] =
- aggUntyped(col1, col2, col3, col4, col5).asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5)]]
+ super.agg(col1, col2, col3, col4, col5)
- /**
- * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
- * and the result of computing these aggregations over all elements in the group.
- *
- * @since 3.0.0
- */
- def agg[U1, U2, U3, U4, U5, U6](
+ /** @inheritdoc */
+ override def agg[U1, U2, U3, U4, U5, U6](
col1: TypedColumn[V, U1],
col2: TypedColumn[V, U2],
col3: TypedColumn[V, U3],
col4: TypedColumn[V, U4],
col5: TypedColumn[V, U5],
col6: TypedColumn[V, U6]): Dataset[(K, U1, U2, U3, U4, U5, U6)] =
- aggUntyped(col1, col2, col3, col4, col5, col6)
- .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6)]]
+ super.agg(col1, col2, col3, col4, col5, col6)
- /**
- * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
- * and the result of computing these aggregations over all elements in the group.
- *
- * @since 3.0.0
- */
- def agg[U1, U2, U3, U4, U5, U6, U7](
+ /** @inheritdoc */
+ override def agg[U1, U2, U3, U4, U5, U6, U7](
col1: TypedColumn[V, U1],
col2: TypedColumn[V, U2],
col3: TypedColumn[V, U3],
@@ -1065,16 +535,10 @@ class KeyValueGroupedDataset[K, V] private[sql](
col5: TypedColumn[V, U5],
col6: TypedColumn[V, U6],
col7: TypedColumn[V, U7]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7)] =
- aggUntyped(col1, col2, col3, col4, col5, col6, col7)
- .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6, U7)]]
+ super.agg(col1, col2, col3, col4, col5, col6, col7)
- /**
- * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
- * and the result of computing these aggregations over all elements in the group.
- *
- * @since 3.0.0
- */
- def agg[U1, U2, U3, U4, U5, U6, U7, U8](
+ /** @inheritdoc */
+ override def agg[U1, U2, U3, U4, U5, U6, U7, U8](
col1: TypedColumn[V, U1],
col2: TypedColumn[V, U2],
col3: TypedColumn[V, U3],
@@ -1083,146 +547,30 @@ class KeyValueGroupedDataset[K, V] private[sql](
col6: TypedColumn[V, U6],
col7: TypedColumn[V, U7],
col8: TypedColumn[V, U8]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7, U8)] =
- aggUntyped(col1, col2, col3, col4, col5, col6, col7, col8)
- .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6, U7, U8)]]
+ super.agg(col1, col2, col3, col4, col5, col6, col7, col8)
- /**
- * Returns a [[Dataset]] that contains a tuple with each key and the number of items present
- * for that key.
- *
- * @since 1.6.0
- */
- def count(): Dataset[(K, Long)] = agg(functions.count("*").as(ExpressionEncoder[Long]()))
+ /** @inheritdoc */
+ override def count(): Dataset[(K, Long)] = super.count()
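A hedged sketch of the typed `agg`/`count` surface above, once more with the assumed `sales` Dataset and `spark.implicits._`: untyped aggregate Columns are turned into `TypedColumn`s with `.as[...]`, and `count()` is shorthand for aggregating `count("*")`.

```scala
import org.apache.spark.sql.functions.{count, sum}

val summary = sales.groupByKey(_._1).agg(
  count("*").as[Long],   // rows per key
  sum($"_2").as[Long]    // total quantity per key
)
summary.show()           // Dataset[(String, Long, Long)]

sales.groupByKey(_._1).count().show()   // same first two columns as `summary`
```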
- /**
- * (Scala-specific)
- * Applies the given function to each cogrouped data. For each unique group, the function will
- * be passed the grouping key and 2 iterators containing all elements in the group from
- * [[Dataset]] `this` and `other`. The function can return an iterator containing elements of an
- * arbitrary type which will be returned as a new [[Dataset]].
- *
- * @since 1.6.0
- */
- def cogroup[U, R : Encoder](
+ /** @inheritdoc */
+ override def cogroup[U, R: Encoder](
other: KeyValueGroupedDataset[K, U])(
- f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = {
- implicit val uEncoder = other.vExprEnc
- Dataset[R](
- sparkSession,
- CoGroup(
- f,
- this.groupingAttributes,
- other.groupingAttributes,
- this.dataAttributes,
- other.dataAttributes,
- Seq.empty,
- Seq.empty,
- this.logicalPlan,
- other.logicalPlan))
- }
+ f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] =
+ super.cogroup(other)(f)
- /**
- * (Java-specific)
- * Applies the given function to each cogrouped data. For each unique group, the function will
- * be passed the grouping key and 2 iterators containing all elements in the group from
- * [[Dataset]] `this` and `other`. The function can return an iterator containing elements of an
- * arbitrary type which will be returned as a new [[Dataset]].
- *
- * @since 1.6.0
- */
- def cogroup[U, R](
+ /** @inheritdoc */
+ override def cogroup[U, R](
other: KeyValueGroupedDataset[K, U],
f: CoGroupFunction[K, V, U, R],
- encoder: Encoder[R]): Dataset[R] = {
- cogroup(other)((key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder)
- }
-
- /**
- * (Scala-specific)
- * Applies the given function to each sorted cogrouped data. For each unique group, the function
- * will be passed the grouping key and 2 sorted iterators containing all elements in the group
- * from [[Dataset]] `this` and `other`. The function can return an iterator containing elements
- * of an arbitrary type which will be returned as a new [[Dataset]].
- *
- * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except that the iterators
- * are sorted according to the given sort expressions. That sorting does not add
- * computational complexity.
- *
- * @see [[org.apache.spark.sql.KeyValueGroupedDataset#cogroup]]
- * @since 3.4.0
- */
- def cogroupSorted[U, R : Encoder](
- other: KeyValueGroupedDataset[K, U])(
- thisSortExprs: Column*)(
- otherSortExprs: Column*)(
- f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = {
- def toSortOrder(col: Column): SortOrder = col.expr match {
- case expr: SortOrder => expr
- case expr: Expression => SortOrder(expr, Ascending)
- }
+ encoder: Encoder[R]): Dataset[R] =
+ super.cogroup(other, f, encoder)
- val thisSortOrder: Seq[SortOrder] = thisSortExprs.map(toSortOrder)
- val otherSortOrder: Seq[SortOrder] = otherSortExprs.map(toSortOrder)
-
- implicit val uEncoder = other.vExprEnc
- Dataset[R](
- sparkSession,
- CoGroup(
- f,
- this.groupingAttributes,
- other.groupingAttributes,
- this.dataAttributes,
- other.dataAttributes,
- thisSortOrder,
- otherSortOrder,
- this.logicalPlan,
- other.logicalPlan))
- }
-
- /**
- * (Java-specific)
- * Applies the given function to each sorted cogrouped data. For each unique group, the function
- * will be passed the grouping key and 2 sorted iterators containing all elements in the group
- * from [[Dataset]] `this` and `other`. The function can return an iterator containing elements
- * of an arbitrary type which will be returned as a new [[Dataset]].
- *
- * This is equivalent to [[KeyValueGroupedDataset#cogroup]], except that the iterators
- * are sorted according to the given sort expressions. That sorting does not add
- * computational complexity.
- *
- * @see [[org.apache.spark.sql.KeyValueGroupedDataset#cogroup]]
- * @since 3.4.0
- */
- def cogroupSorted[U, R](
+ /** @inheritdoc */
+ override def cogroupSorted[U, R](
other: KeyValueGroupedDataset[K, U],
thisSortExprs: Array[Column],
otherSortExprs: Array[Column],
f: CoGroupFunction[K, V, U, R],
- encoder: Encoder[R]): Dataset[R] = {
- import org.apache.spark.util.ArrayImplicits._
- cogroupSorted(other)(
- thisSortExprs.toImmutableArraySeq: _*)(otherSortExprs.toImmutableArraySeq: _*)(
- (key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder)
- }
-
- override def toString: String = {
- val builder = new StringBuilder
- val kFields = kExprEnc.schema.map {
- case f => s"${f.name}: ${f.dataType.simpleString(2)}"
- }
- val vFields = vExprEnc.schema.map {
- case f => s"${f.name}: ${f.dataType.simpleString(2)}"
- }
- builder.append("KeyValueGroupedDataset: [key: [")
- builder.append(kFields.take(2).mkString(", "))
- if (kFields.length > 2) {
- builder.append(" ... " + (kFields.length - 2) + " more field(s)")
- }
- builder.append("], value: [")
- builder.append(vFields.take(2).mkString(", "))
- if (vFields.length > 2) {
- builder.append(" ... " + (vFields.length - 2) + " more field(s)")
- }
- builder.append("]]").toString()
- }
+ encoder: Encoder[R]): Dataset[R] =
+ super.cogroupSorted(other, thisSortExprs, otherSortExprs, f, encoder)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala
deleted file mode 100644
index 6212a7fdb259c..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/MergeIntoWriter.scala
+++ /dev/null
@@ -1,353 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql
-
-import org.apache.spark.SparkRuntimeException
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
-import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.plans.logical.{Assignment, DeleteAction, InsertAction, InsertStarAction, MergeAction, MergeIntoTable, UpdateAction, UpdateStarAction}
-import org.apache.spark.sql.functions.expr
-
-/**
- * `MergeIntoWriter` provides methods to define and execute merge actions based
- * on specified conditions.
- *
- * @tparam T the type of data in the Dataset.
- * @param table the name of the target table for the merge operation.
- * @param ds the source Dataset to merge into the target table.
- * @param on the merge condition.
- * @param schemaEvolutionEnabled whether to enable automatic schema evolution for this merge
- * operation. Default is `false`.
- *
- * @since 4.0.0
- */
-@Experimental
-class MergeIntoWriter[T] private[sql] (
- table: String,
- ds: Dataset[T],
- on: Column,
- private[sql] val schemaEvolutionEnabled: Boolean = false) {
-
- private val df: DataFrame = ds.toDF()
-
- private[sql] val sparkSession = ds.sparkSession
- import sparkSession.RichColumn
-
- private val tableName = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(table)
-
- private val logicalPlan = df.queryExecution.logical
-
- private[sql] var matchedActions: Seq[MergeAction] = Seq.empty[MergeAction]
- private[sql] var notMatchedActions: Seq[MergeAction] = Seq.empty[MergeAction]
- private[sql] var notMatchedBySourceActions: Seq[MergeAction] = Seq.empty[MergeAction]
-
- /**
- * Initialize a `WhenMatched` action without any condition.
- *
- * This `WhenMatched` action will be executed when a source row matches a target table row based
- * on the merge condition.
- *
- * This `WhenMatched` can be followed by one of the following merge actions:
- * - `updateAll`: Update all the matched target table rows with source dataset rows.
- * - `update(Map)`: Update all the matched target table rows while changing only
- * a subset of columns based on the provided assignment.
- * - `delete`: Delete all target rows that have a match in the source table.
- *
- * @return a new `WhenMatched` object.
- */
- def whenMatched(): WhenMatched[T] = {
- new WhenMatched[T](this, None)
- }
-
- /**
- * Initialize a `WhenMatched` action with a condition.
- *
- * This `WhenMatched` action will be executed when a source row matches a target table row based
- * on the merge condition and the specified `condition` is satisfied.
- *
- * This `WhenMatched` can be followed by one of the following merge actions:
- * - `updateAll`: Update all the matched target table rows with source dataset rows.
- * - `update(Map)`: Update all the matched target table rows while changing only
- * a subset of columns based on the provided assignment.
- * - `delete`: Delete all target rows that have a match in the source table.
- *
- * @param condition a `Column` representing the condition to be evaluated for the action.
- * @return a new `WhenMatched` object configured with the specified condition.
- */
- def whenMatched(condition: Column): WhenMatched[T] = {
- new WhenMatched[T](this, Some(condition.expr))
- }
-
- /**
- * Initialize a `WhenNotMatched` action without any condition.
- *
- * This `WhenNotMatched` action will be executed when a source row does not match any target row
- * based on the merge condition.
- *
- * This `WhenNotMatched` can be followed by one of the following merge actions:
- * - `insertAll`: Insert all rows from the source that are not already in the target table.
- * - `insert(Map)`: Insert all rows from the source that are not already in the target table,
- * with the specified columns based on the provided assignment.
- *
- * @return a new `WhenNotMatched` object.
- */
- def whenNotMatched(): WhenNotMatched[T] = {
- new WhenNotMatched[T](this, None)
- }
-
- /**
- * Initialize a `WhenNotMatched` action with a condition.
- *
- * This `WhenNotMatched` action will be executed when a source row does not match any target row
- * based on the merge condition and the specified `condition` is satisfied.
- *
- * This `WhenNotMatched` can be followed by one of the following merge actions:
- * - `insertAll`: Insert all rows from the source that are not already in the target table.
- * - `insert(Map)`: Insert all rows from the source that are not already in the target table,
- * with the specified columns based on the provided assignment.
- *
- * @param condition a `Column` representing the condition to be evaluated for the action.
- * @return a new `WhenNotMatched` object configured with the specified condition.
- */
- def whenNotMatched(condition: Column): WhenNotMatched[T] = {
- new WhenNotMatched[T](this, Some(condition.expr))
- }
-
- /**
- * Initialize a `WhenNotMatchedBySource` action without any condition.
- *
- * This `WhenNotMatchedBySource` action will be executed when a target row does not match any
- * rows in the source table based on the merge condition.
- *
- * This `WhenNotMatchedBySource` can be followed by one of the following merge actions:
- * - `updateAll`: Update all the not matched target table rows with source dataset rows.
- * - `update(Map)`: Update all the not matched target table rows while changing only
- * the specified columns based on the provided assignment.
- * - `delete`: Delete all target rows that have no matches in the source table.
- *
- * @return a new `WhenNotMatchedBySource` object.
- */
- def whenNotMatchedBySource(): WhenNotMatchedBySource[T] = {
- new WhenNotMatchedBySource[T](this, None)
- }
-
- /**
- * Initialize a `WhenNotMatchedBySource` action with a condition.
- *
- * This `WhenNotMatchedBySource` action will be executed when a target row does not match any
- * rows in the source table based on the merge condition and the specified `condition`
- * is satisfied.
- *
- * This `WhenNotMatchedBySource` can be followed by one of the following merge actions:
- * - `updateAll`: Update all the not matched target table rows with source dataset rows.
- * - `update(Map)`: Update all the not matched target table rows while changing only
- * the specified columns based on the provided assignment.
- * - `delete`: Delete all target rows that have no matches in the source table.
- *
- * @param condition a `Column` representing the condition to be evaluated for the action.
- * @return a new `WhenNotMatchedBySource` object configured with the specified condition.
- */
- def whenNotMatchedBySource(condition: Column): WhenNotMatchedBySource[T] = {
- new WhenNotMatchedBySource[T](this, Some(condition.expr))
- }
-
- /**
- * Enable automatic schema evolution for this merge operation.
- * @return A `MergeIntoWriter` instance with schema evolution enabled.
- */
- def withSchemaEvolution(): MergeIntoWriter[T] = {
- new MergeIntoWriter[T](this.table, this.ds, this.on, schemaEvolutionEnabled = true)
- .withNewMatchedActions(this.matchedActions: _*)
- .withNewNotMatchedActions(this.notMatchedActions: _*)
- .withNewNotMatchedBySourceActions(this.notMatchedBySourceActions: _*)
- }
-
- /**
- * Executes the merge operation.
- */
- def merge(): Unit = {
- if (matchedActions.isEmpty && notMatchedActions.isEmpty && notMatchedBySourceActions.isEmpty) {
- throw new SparkRuntimeException(
- errorClass = "NO_MERGE_ACTION_SPECIFIED",
- messageParameters = Map.empty)
- }
-
- val merge = MergeIntoTable(
- UnresolvedRelation(tableName).requireWritePrivileges(MergeIntoTable.getWritePrivileges(
- matchedActions, notMatchedActions, notMatchedBySourceActions)),
- logicalPlan,
- on.expr,
- matchedActions,
- notMatchedActions,
- notMatchedBySourceActions,
- schemaEvolutionEnabled)
- val qe = sparkSession.sessionState.executePlan(merge)
- qe.assertCommandExecuted()
- }
-
- private[sql] def withNewMatchedActions(actions: MergeAction*): MergeIntoWriter[T] = {
- this.matchedActions ++= actions
- this
- }
-
- private[sql] def withNewNotMatchedActions(actions: MergeAction*): MergeIntoWriter[T] = {
- this.notMatchedActions ++= actions
- this
- }
-
- private[sql] def withNewNotMatchedBySourceActions(actions: MergeAction*): MergeIntoWriter[T] = {
- this.notMatchedBySourceActions ++= actions
- this
- }
-}
-
-/**
- * A class for defining actions to be taken when matching rows in a DataFrame during
- * a merge operation.
- *
- * @param mergeIntoWriter The MergeIntoWriter instance responsible for writing data to a
- * target DataFrame.
- * @param condition An optional condition Expression that specifies when the actions
- * should be applied.
- * If the condition is None, the actions will be applied to all matched
- * rows.
- *
- * @tparam T The type of data in the MergeIntoWriter.
- */
-case class WhenMatched[T] private[sql](
- mergeIntoWriter: MergeIntoWriter[T],
- condition: Option[Expression]) {
- import mergeIntoWriter.sparkSession.RichColumn
-
- /**
- * Specifies an action to update all matched rows in the DataFrame.
- *
- * @return The MergeIntoWriter instance with the update all action configured.
- */
- def updateAll(): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewMatchedActions(UpdateStarAction(condition))
- }
-
- /**
- * Specifies an action to update matched rows in the DataFrame with the provided column
- * assignments.
- *
- * @param map A Map of column names to Column expressions representing the updates to be applied.
- * @return The MergeIntoWriter instance with the update action configured.
- */
- def update(map: Map[String, Column]): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewMatchedActions(
- UpdateAction(condition, map.map(x => Assignment(expr(x._1).expr, x._2.expr)).toSeq))
- }
-
- /**
- * Specifies an action to delete matched rows from the DataFrame.
- *
- * @return The MergeIntoWriter instance with the delete action configured.
- */
- def delete(): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewMatchedActions(DeleteAction(condition))
- }
-}
-
-/**
- * A class for defining actions to be taken when no matching rows are found in a DataFrame
- * during a merge operation.
- *
- * @param MergeIntoWriter The MergeIntoWriter instance responsible for writing data to a
- * target DataFrame.
- * @param condition An optional condition Expression that specifies when the actions
- * defined in this configuration should be applied.
- * If the condition is None, the actions will be applied when there
- * are no matching rows.
- * @tparam T The type of data in the MergeIntoWriter.
- */
-case class WhenNotMatched[T] private[sql](
- mergeIntoWriter: MergeIntoWriter[T],
- condition: Option[Expression]) {
- import mergeIntoWriter.sparkSession.RichColumn
-
- /**
- * Specifies an action to insert all non-matched rows into the DataFrame.
- *
- * @return The MergeIntoWriter instance with the insert all action configured.
- */
- def insertAll(): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewNotMatchedActions(InsertStarAction(condition))
- }
-
- /**
- * Specifies an action to insert non-matched rows into the DataFrame with the provided
- * column assignments.
- *
- * @param map A Map of column names to Column expressions representing the values to be inserted.
- * @return The MergeIntoWriter instance with the insert action configured.
- */
- def insert(map: Map[String, Column]): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewNotMatchedActions(
- InsertAction(condition, map.map(x => Assignment(expr(x._1).expr, x._2.expr)).toSeq))
- }
-}
-
-
-/**
- * A class for defining actions to be performed when there is no match by source
- * during a merge operation in a MergeIntoWriter.
- *
- * @param MergeIntoWriter the MergeIntoWriter instance to which the merge actions will be applied.
- * @param condition an optional condition to be used with the merge actions.
- * @tparam T the type parameter for the MergeIntoWriter.
- */
-case class WhenNotMatchedBySource[T] private[sql](
- mergeIntoWriter: MergeIntoWriter[T],
- condition: Option[Expression]) {
- import mergeIntoWriter.sparkSession.RichColumn
-
- /**
- * Specifies an action to update all non-matched rows in the target DataFrame when
- * not matched by the source.
- *
- * @return The MergeIntoWriter instance with the update all action configured.
- */
- def updateAll(): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewNotMatchedBySourceActions(UpdateStarAction(condition))
- }
-
- /**
- * Specifies an action to update non-matched rows in the target DataFrame with the provided
- * column assignments when not matched by the source.
- *
- * @param map A Map of column names to Column expressions representing the updates to be applied.
- * @return The MergeIntoWriter instance with the update action configured.
- */
- def update(map: Map[String, Column]): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewNotMatchedBySourceActions(
- UpdateAction(condition, map.map(x => Assignment(expr(x._1).expr, x._2.expr)).toSeq))
- }
-
- /**
- * Specifies an action to delete non-matched rows from the target DataFrame when not matched by
- * the source.
- *
- * @return The MergeIntoWriter instance with the delete action configured.
- */
- def delete(): MergeIntoWriter[T] = {
- mergeIntoWriter.withNewNotMatchedBySourceActions(DeleteAction(condition))
- }
-}
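
For reference, a hedged usage sketch of the builder documented in the file removed above. The entry point `Dataset.mergeInto` and a target table named `events` backed by a format that supports MERGE (e.g. Delta or Iceberg) are assumptions of this sketch; the `whenMatched`/`whenNotMatched`/`whenNotMatchedBySource` chain itself is the API described in the removed Scaladoc.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

val spark = SparkSession.builder().getOrCreate()
import spark.implicits._

// Assumed source data; the "source" alias lets the merge condition reference it by name.
val updates = Seq((1, "click"), (2, "view")).toDF("id", "kind").alias("source")

updates
  .mergeInto("events", col("events.id") === col("source.id"))
  .whenMatched()                              // rows present in both: overwrite with source values
  .updateAll()
  .whenNotMatched()                           // rows only in the source: insert them
  .insertAll()
  .whenNotMatchedBySource(col("events.kind") === "stale")
  .delete()                                   // target-only rows flagged stale: remove them
  .merge()
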
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 777baa3e62687..bd47a21a1e09b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -22,13 +22,13 @@ import org.apache.spark.annotation.Stable
import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
-import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.toPrettySQL
+import org.apache.spark.sql.classic.ClassicConversions._
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.internal.ExpressionUtils.{column, generateAlias}
@@ -53,8 +53,8 @@ class RelationalGroupedDataset protected[sql](
protected[sql] val df: DataFrame,
private[sql] val groupingExprs: Seq[Expression],
groupType: RelationalGroupedDataset.GroupType)
- extends api.RelationalGroupedDataset[Dataset] {
- type RGD = RelationalGroupedDataset
+ extends api.RelationalGroupedDataset {
+
import RelationalGroupedDataset._
import df.sparkSession._
@@ -117,23 +117,14 @@ class RelationalGroupedDataset protected[sql](
columnExprs.map(column)
}
-
- /**
- * Returns a `KeyValueGroupedDataset` where the data is grouped by the grouping expressions
- * of current `RelationalGroupedDataset`.
- *
- * @since 3.0.0
- */
+ /** @inheritdoc */
def as[K: Encoder, T: Encoder]: KeyValueGroupedDataset[K, T] = {
- val keyEncoder = encoderFor[K]
- val valueEncoder = encoderFor[T]
-
val (qe, groupingAttributes) =
handleGroupingExpression(df.logicalPlan, df.sparkSession, groupingExprs)
new KeyValueGroupedDataset(
- keyEncoder,
- valueEncoder,
+ implicitly[Encoder[K]],
+ implicitly[Encoder[T]],
qe,
df.logicalPlan.output,
groupingAttributes)
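
For reference, a minimal sketch (local data, assumed names) of the `RelationalGroupedDataset.as[K, T]` path touched above: a relational `groupBy` is turned back into a typed `KeyValueGroupedDataset`, with the key and value encoders now passed through as implicits instead of being resolved eagerly via `encoderFor`.

import org.apache.spark.sql.SparkSession

case class Sale(shop: String, amount: Long)

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

val sales = Seq(Sale("a", 10L), Sale("a", 5L), Sale("b", 7L)).toDS()

val perShop = sales
  .groupBy($"shop")
  .as[String, Sale]                 // RelationalGroupedDataset -> KeyValueGroupedDataset
  .mapGroups((shop, rows) => (shop, rows.map(_.amount).sum))

perShop.show()
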
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 55f67da68221f..137dbaed9f00a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -20,8 +20,10 @@ package org.apache.spark.sql
import java.net.URI
import java.nio.file.Paths
import java.util.{ServiceLoader, UUID}
+import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}
+import scala.concurrent.duration.DurationInt
import scala.jdk.CollectionConverters._
import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal
@@ -57,7 +59,7 @@ import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.streaming._
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.util.ExecutionListenerManager
-import org.apache.spark.util.{CallSite, SparkFileUtils, Utils}
+import org.apache.spark.util.{CallSite, SparkFileUtils, ThreadUtils, Utils}
import org.apache.spark.util.ArrayImplicits._
/**
@@ -92,8 +94,9 @@ class SparkSession private(
@transient private val existingSharedState: Option[SharedState],
@transient private val parentSessionState: Option[SessionState],
@transient private[sql] val extensions: SparkSessionExtensions,
- @transient private[sql] val initialSessionOptions: Map[String, String])
- extends api.SparkSession[Dataset] with Logging { self =>
+ @transient private[sql] val initialSessionOptions: Map[String, String],
+ @transient private val parentManagedJobTags: Map[String, String])
+ extends api.SparkSession with Logging { self =>
// The call site where this SparkSession was constructed.
private val creationSite: CallSite = Utils.getCallSite()
@@ -107,7 +110,12 @@ class SparkSession private(
private[sql] def this(
sc: SparkContext,
initialSessionOptions: java.util.HashMap[String, String]) = {
- this(sc, None, None, applyAndLoadExtensions(sc), initialSessionOptions.asScala.toMap)
+ this(
+ sc,
+ existingSharedState = None,
+ parentSessionState = None,
+ applyAndLoadExtensions(sc), initialSessionOptions.asScala.toMap,
+ parentManagedJobTags = Map.empty)
}
private[sql] def this(sc: SparkContext) = this(sc, new java.util.HashMap[String, String]())
@@ -122,6 +130,18 @@ class SparkSession private(
.getOrElse(SQLConf.getFallbackConf)
})
+ /** Tag to mark all jobs owned by this session. */
+ private[sql] lazy val sessionJobTag = s"spark-session-$sessionUUID"
+
+ /**
+ * A map to hold the mapping from user-defined tags to the real tags attached to Jobs.
+ * Real tags have the current session ID attached: `"tag1" -> s"spark-session-$sessionUUID-tag1"`.
+ */
+ @transient
+ private[sql] lazy val managedJobTags: ConcurrentHashMap[String, String] = {
+ new ConcurrentHashMap(parentManagedJobTags.asJava)
+ }
+
/** @inheritdoc */
def version: String = SPARK_VERSION
@@ -173,16 +193,8 @@ class SparkSession private(
@transient
val sqlContext: SQLContext = new SQLContext(this)
- /**
- * Runtime configuration interface for Spark.
- *
- * This is the interface through which the user can get and set all Spark and Hadoop
- * configurations that are relevant to Spark SQL. When getting the value of a config,
- * this defaults to the value set in the underlying `SparkContext`, if any.
- *
- * @since 2.0.0
- */
- @transient lazy val conf: RuntimeConfig = new RuntimeConfig(sessionState.conf)
+ /** @inheritdoc */
+ @transient lazy val conf: RuntimeConfig = new RuntimeConfigImpl(sessionState.conf)
/**
* An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s
@@ -243,7 +255,8 @@ class SparkSession private(
Some(sharedState),
parentSessionState = None,
extensions,
- initialSessionOptions)
+ initialSessionOptions,
+ parentManagedJobTags = Map.empty)
}
/**
@@ -264,8 +277,10 @@ class SparkSession private(
Some(sharedState),
Some(sessionState),
extensions,
- Map.empty)
+ Map.empty,
+ managedJobTags.asScala.toMap)
result.sessionState // force copy of SessionState
+ result.managedJobTags // force initialization of the cloned session's managedJobTags
result
}
@@ -480,12 +495,7 @@ class SparkSession private(
| Catalog-related methods |
* ------------------------- */
- /**
- * Interface through which the user may create, drop, alter or query underlying
- * databases, tables, functions etc.
- *
- * @since 2.0.0
- */
+ /** @inheritdoc */
@transient lazy val catalog: Catalog = new CatalogImpl(self)
/** @inheritdoc */
@@ -649,16 +659,84 @@ class SparkSession private(
artifactManager.addLocalArtifacts(uri.flatMap(Artifact.parseArtifacts))
}
+ /** @inheritdoc */
+ override def addTag(tag: String): Unit = {
+ SparkContext.throwIfInvalidTag(tag)
+ managedJobTags.put(tag, s"spark-session-$sessionUUID-$tag")
+ }
+
+ /** @inheritdoc */
+ override def removeTag(tag: String): Unit = managedJobTags.remove(tag)
+
+ /** @inheritdoc */
+ override def getTags(): Set[String] = managedJobTags.keys().asScala.toSet
+
+ /** @inheritdoc */
+ override def clearTags(): Unit = managedJobTags.clear()
+
/**
- * Returns a [[DataFrameReader]] that can be used to read non-streaming data in as a
- * `DataFrame`.
- * {{{
- * sparkSession.read.parquet("/path/to/file.parquet")
- * sparkSession.read.schema(schema).json("/path/to/file.json")
- * }}}
+ * Request to interrupt all currently running SQL operations of this session.
*
- * @since 2.0.0
+ * @note Only DataFrame/SQL operations started by this session can be interrupted.
+ *
+ * @note This method will wait up to 60 seconds for the interruption request to be issued.
+ *
+ * @return Sequence of SQL execution IDs requested to be interrupted.
+ *
+ * @since 4.0.0
*/
+ override def interruptAll(): Seq[String] =
+ doInterruptTag(sessionJobTag, "as part of cancellation of all jobs")
+
+ /**
+ * Request to interrupt all currently running SQL operations of this session with the given
+ * job tag.
+ *
+ * @note Only DataFrame/SQL operations started by this session can be interrupted.
+ *
+ * @note This method will wait up to 60 seconds for the interruption request to be issued.
+ *
+ * @return Sequence of SQL execution IDs requested to be interrupted.
+ *
+ * @since 4.0.0
+ */
+ override def interruptTag(tag: String): Seq[String] = {
+ val realTag = managedJobTags.get(tag)
+ if (realTag == null) return Seq.empty
+ doInterruptTag(realTag, s"part of cancelled job tags $tag")
+ }
+
+ private def doInterruptTag(tag: String, reason: String): Seq[String] = {
+ val cancelledTags =
+ sparkContext.cancelJobsWithTagWithFuture(tag, reason)
+
+ ThreadUtils.awaitResult(cancelledTags, 60.seconds)
+ .flatMap(job => Option(job.properties.getProperty(SQLExecution.EXECUTION_ROOT_ID_KEY)))
+ }
+
+ /**
+ * Request to interrupt a SQL operation of this session, given its SQL execution ID.
+ *
+ * @note Only DataFrame/SQL operations started by this session can be interrupted.
+ *
+ * @note This method will wait up to 60 seconds for the interruption request to be issued.
+ *
+ * @return The execution ID requested to be interrupted, as a single-element sequence, or an empty
+ * sequence if the operation was not started by this session.
+ *
+ * @since 4.0.0
+ */
+ override def interruptOperation(operationId: String): Seq[String] = {
+ scala.util.Try(operationId.toLong).toOption match {
+ case Some(executionIdToBeCancelled) =>
+ val tagToBeCancelled = SQLExecution.executionIdJobTag(this, executionIdToBeCancelled)
+ doInterruptTag(tagToBeCancelled, reason = "")
+ case None =>
+ throw new IllegalArgumentException("executionId must be a number in string form.")
+ }
+ }
+
+ /** @inheritdoc */
def read: DataFrameReader = new DataFrameReader(self)
/**
@@ -744,7 +822,7 @@ class SparkSession private(
}
/**
- * Execute a block of code with the this session set as the active session, and restore the
+ * Execute a block of code with this session set as the active session, and restore the
* previous session on completion.
*/
private[sql] def withActive[T](block: => T): T = {
@@ -759,7 +837,8 @@ class SparkSession private(
}
private[sql] def leafNodeDefaultParallelism: Int = {
- conf.get(SQLConf.LEAF_NODE_DEFAULT_PARALLELISM).getOrElse(sparkContext.defaultParallelism)
+ sessionState.conf.getConf(SQLConf.LEAF_NODE_DEFAULT_PARALLELISM)
+ .getOrElse(sparkContext.defaultParallelism)
}
private[sql] object Converter extends ColumnNodeToExpressionConverter with Serializable {
@@ -979,7 +1058,12 @@ object SparkSession extends Logging {
loadExtensions(extensions)
applyExtensions(sparkContext, extensions)
- session = new SparkSession(sparkContext, None, None, extensions, options.toMap)
+ session = new SparkSession(sparkContext,
+ existingSharedState = None,
+ parentSessionState = None,
+ extensions,
+ initialSessionOptions = options.toMap,
+ parentManagedJobTags = Map.empty)
setDefaultSession(session)
setActiveSession(session)
registerContextListener(sparkContext)
@@ -1124,13 +1208,13 @@ object SparkSession extends Logging {
private[sql] def getOrCloneSessionWithConfigsOff(
session: SparkSession,
configurations: Seq[ConfigEntry[Boolean]]): SparkSession = {
- val configsEnabled = configurations.filter(session.conf.get[Boolean])
+ val configsEnabled = configurations.filter(session.sessionState.conf.getConf[Boolean])
if (configsEnabled.isEmpty) {
session
} else {
val newSession = session.cloneSession()
configsEnabled.foreach(conf => {
- newSession.conf.set(conf, false)
+ newSession.sessionState.conf.setConf(conf, false)
})
newSession
}
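
For reference, a hedged sketch of the session-scoped job-tag API added above: user tags are mapped internally to `spark-session-<uuid>-<tag>`, and `interruptTag`/`interruptAll` only affect jobs started by this session. The workload, timing, and tag name below are illustrative.

import scala.concurrent.Future
import scala.concurrent.ExecutionContext.Implicits.global

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()

spark.addTag("nightly-report")                    // tags all work started by this session from here on
val report = Future {
  spark.range(0, 10000000000L).selectExpr("sum(id)").collect()
}

// Later, from any thread: cancel just the tagged work, or everything in the session.
val interrupted: Seq[String] = spark.interruptTag("nightly-report")
// spark.interruptAll()

spark.removeTag("nightly-report")
spark.clearTags()
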
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index c1c9af2ea4273..bc270e6ac64ad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -29,7 +29,7 @@ import org.apache.spark.internal.LogKeys.CLASS_LOADER
import org.apache.spark.security.SocketAuthServer
import org.apache.spark.sql.{internal, Column, DataFrame, Row, SparkSession}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
-import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
+import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TableFunctionRegistry}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
@@ -69,7 +69,10 @@ private[sql] object PythonSQLUtils extends Logging {
// This is needed when generating SQL documentation for built-in functions.
def listBuiltinFunctionInfos(): Array[ExpressionInfo] = {
- FunctionRegistry.functionSet.flatMap(f => FunctionRegistry.builtin.lookupFunction(f)).toArray
+ (FunctionRegistry.functionSet.flatMap(f => FunctionRegistry.builtin.lookupFunction(f)) ++
+ TableFunctionRegistry.functionSet.flatMap(
+ f => TableFunctionRegistry.builtin.lookupFunction(f))).
+ groupBy(_.getName).map(v => v._2.head).toArray
}
private def listAllSQLConfigs(): Seq[(String, String, String, String)] = {
@@ -152,6 +155,9 @@ private[sql] object PythonSQLUtils extends Logging {
def collect_top_k(e: Column, num: Int, reverse: Boolean): Column =
Column.internalFn("collect_top_k", e, lit(num), lit(reverse))
+ def binary_search(e: Column, value: Column): Column =
+ Column.internalFn("array_binary_search", e, value)
+
def pandasProduct(e: Column, ignoreNA: Boolean): Column =
Column.internalFn("pandas_product", e, lit(ignoreNA))
@@ -173,6 +179,13 @@ private[sql] object PythonSQLUtils extends Logging {
def pandasCovar(col1: Column, col2: Column, ddof: Int): Column =
Column.internalFn("pandas_covar", col1, col2, lit(ddof))
+ /**
+ * A long column that increases one by one.
+ * This is for 'distributed-sequence' default index in pandas API on Spark.
+ */
+ def distributed_sequence_id(): Column =
+ Column.internalFn("distributed_sequence_id")
+
def unresolvedNamedLambdaVariable(name: String): Column =
Column(internal.UnresolvedNamedLambdaVariable.apply(name))
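
The `distributed_sequence_id` helper registered above backs the 'distributed-sequence' default index in pandas API on Spark. A rough public-API analogue (not the actual implementation) is a zero-based, gap-free row id built with `zipWithIndex`:

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{LongType, StructField, StructType}

val spark = SparkSession.builder().master("local[*]").getOrCreate()
val df = spark.range(5).toDF("value")

// Prepend a monotonically increasing, gap-free id column to every row.
val withSeqId = spark.createDataFrame(
  df.rdd.zipWithIndex().map { case (row, idx) => Row.fromSeq(idx +: row.toSeq) },
  StructType(StructField("seq_id", LongType, nullable = false) +: df.schema.fields))
withSeqId.show()
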
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala
index 89aa7cfd56071..1ee960622fc2a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala
@@ -64,11 +64,12 @@ class ArtifactManager(session: SparkSession) extends Logging {
// The base directory/URI where all artifacts are stored for this `sessionUUID`.
protected[artifact] val (artifactPath, artifactURI): (Path, String) =
(ArtifactUtils.concatenatePaths(artifactRootPath, session.sessionUUID),
- s"$artifactRootURI/${session.sessionUUID}")
+ s"$artifactRootURI${File.separator}${session.sessionUUID}")
// The base directory/URI where all class file artifacts are stored for this `sessionUUID`.
protected[artifact] val (classDir, classURI): (Path, String) =
- (ArtifactUtils.concatenatePaths(artifactPath, "classes"), s"$artifactURI/classes/")
+ (ArtifactUtils.concatenatePaths(artifactPath, "classes"),
+ s"$artifactURI${File.separator}classes${File.separator}")
protected[artifact] val state: JobArtifactState =
JobArtifactState(session.sessionUUID, Option(classURI))
@@ -112,6 +113,14 @@ class ArtifactManager(session: SparkSession) extends Logging {
}
}
+ private def normalizePath(path: Path): Path = {
+ // Convert the path to a string with the current system's separator
+ val normalizedPathString = path.toString
+ .replace('/', File.separatorChar)
+ .replace('\\', File.separatorChar)
+ // Convert the normalized string back to a Path object
+ Paths.get(normalizedPathString).normalize()
+ }
/**
* Add and prepare a staged artifact (i.e an artifact that has been rebuilt locally from bytes
* over the wire) for use.
@@ -128,14 +137,14 @@ class ArtifactManager(session: SparkSession) extends Logging {
deleteStagedFile: Boolean = true
): Unit = JobArtifactSet.withActiveJobArtifactState(state) {
require(!remoteRelativePath.isAbsolute)
-
- if (remoteRelativePath.startsWith(s"cache${File.separator}")) {
+ val normalizedRemoteRelativePath = normalizePath(remoteRelativePath)
+ if (normalizedRemoteRelativePath.startsWith(s"cache${File.separator}")) {
val tmpFile = serverLocalStagingPath.toFile
Utils.tryWithSafeFinallyAndFailureCallbacks {
val blockManager = session.sparkContext.env.blockManager
val blockId = CacheId(
sessionUUID = session.sessionUUID,
- hash = remoteRelativePath.toString.stripPrefix(s"cache${File.separator}"))
+ hash = normalizedRemoteRelativePath.toString.stripPrefix(s"cache${File.separator}"))
val updater = blockManager.TempFileBasedBlockStoreUpdater(
blockId = blockId,
level = StorageLevel.MEMORY_AND_DISK_SER,
@@ -145,11 +154,11 @@ class ArtifactManager(session: SparkSession) extends Logging {
tellMaster = false)
updater.save()
}(catchBlock = { tmpFile.delete() })
- } else if (remoteRelativePath.startsWith(s"classes${File.separator}")) {
+ } else if (normalizedRemoteRelativePath.startsWith(s"classes${File.separator}")) {
// Move class files to the right directory.
val target = ArtifactUtils.concatenatePaths(
classDir,
- remoteRelativePath.toString.stripPrefix(s"classes${File.separator}"))
+ normalizedRemoteRelativePath.toString.stripPrefix(s"classes${File.separator}"))
// Allow overwriting class files to capture updates to classes.
// This is required because the client currently sends all the class files in each class file
// transfer.
@@ -159,7 +168,7 @@ class ArtifactManager(session: SparkSession) extends Logging {
allowOverwrite = true,
deleteSource = deleteStagedFile)
} else {
- val target = ArtifactUtils.concatenatePaths(artifactPath, remoteRelativePath)
+ val target = ArtifactUtils.concatenatePaths(artifactPath, normalizedRemoteRelativePath)
// Disallow overwriting with modified version
if (Files.exists(target)) {
// makes the query idempotent
@@ -167,30 +176,30 @@ class ArtifactManager(session: SparkSession) extends Logging {
return
}
- throw new RuntimeException(s"Duplicate Artifact: $remoteRelativePath. " +
+ throw new RuntimeException(s"Duplicate Artifact: $normalizedRemoteRelativePath. " +
"Artifacts cannot be overwritten.")
}
transferFile(serverLocalStagingPath, target, deleteSource = deleteStagedFile)
// This URI is for Spark file server that starts with "spark://".
val uri = s"$artifactURI/${Utils.encodeRelativeUnixPathToURIRawPath(
- FilenameUtils.separatorsToUnix(remoteRelativePath.toString))}"
+ FilenameUtils.separatorsToUnix(normalizedRemoteRelativePath.toString))}"
- if (remoteRelativePath.startsWith(s"jars${File.separator}")) {
+ if (normalizedRemoteRelativePath.startsWith(s"jars${File.separator}")) {
session.sparkContext.addJar(uri)
jarsList.add(target)
- } else if (remoteRelativePath.startsWith(s"pyfiles${File.separator}")) {
+ } else if (normalizedRemoteRelativePath.startsWith(s"pyfiles${File.separator}")) {
session.sparkContext.addFile(uri)
- val stringRemotePath = remoteRelativePath.toString
+ val stringRemotePath = normalizedRemoteRelativePath.toString
if (stringRemotePath.endsWith(".zip") || stringRemotePath.endsWith(
".egg") || stringRemotePath.endsWith(".jar")) {
pythonIncludeList.add(target.getFileName.toString)
}
- } else if (remoteRelativePath.startsWith(s"archives${File.separator}")) {
+ } else if (normalizedRemoteRelativePath.startsWith(s"archives${File.separator}")) {
val canonicalUri =
fragment.map(Utils.getUriBuilder(new URI(uri)).fragment).getOrElse(new URI(uri))
session.sparkContext.addArchive(canonicalUri.toString)
- } else if (remoteRelativePath.startsWith(s"files${File.separator}")) {
+ } else if (normalizedRemoteRelativePath.startsWith(s"files${File.separator}")) {
session.sparkContext.addFile(uri)
}
}
@@ -301,20 +310,21 @@ class ArtifactManager(session: SparkSession) extends Logging {
def uploadArtifactToFs(
remoteRelativePath: Path,
serverLocalStagingPath: Path): Unit = {
+ val normalizedRemoteRelativePath = normalizePath(remoteRelativePath)
val hadoopConf = session.sparkContext.hadoopConfiguration
assert(
- remoteRelativePath.startsWith(
+ normalizedRemoteRelativePath.startsWith(
ArtifactManager.forwardToFSPrefix + File.separator))
val destFSPath = new FSPath(
Paths
- .get("/")
- .resolve(remoteRelativePath.subpath(1, remoteRelativePath.getNameCount))
+ .get(File.separator)
+ .resolve(normalizedRemoteRelativePath.subpath(1, normalizedRemoteRelativePath.getNameCount))
.toString)
val localPath = serverLocalStagingPath
val fs = destFSPath.getFileSystem(hadoopConf)
if (fs.isInstanceOf[LocalFileSystem]) {
val allowDestLocalConf =
- session.conf.get(SQLConf.ARTIFACT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL)
+ session.sessionState.conf.getConf(SQLConf.ARTIFACT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL)
.getOrElse(
session.conf.get("spark.connect.copyFromLocalToFs.allowDestLocal").contains("true"))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
index 676be7fe41cbc..c39018ff06fca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
@@ -14,657 +14,153 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
package org.apache.spark.sql.catalog
-import scala.jdk.CollectionConverters._
+import java.util
-import org.apache.spark.annotation.Stable
-import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset}
+import org.apache.spark.sql.{api, DataFrame, Dataset}
+import org.apache.spark.sql.classic.ClassicConversions._
import org.apache.spark.sql.types.StructType
-import org.apache.spark.storage.StorageLevel
-
-/**
- * Catalog interface for Spark. To access this, use `SparkSession.catalog`.
- *
- * @since 2.0.0
- */
-@Stable
-abstract class Catalog {
-
- /**
- * Returns the current database (namespace) in this session.
- *
- * @since 2.0.0
- */
- def currentDatabase: String
-
- /**
- * Sets the current database (namespace) in this session.
- *
- * @since 2.0.0
- */
- def setCurrentDatabase(dbName: String): Unit
-
- /**
- * Returns a list of databases (namespaces) available within the current catalog.
- *
- * @since 2.0.0
- */
- def listDatabases(): Dataset[Database]
-
- /**
- * Returns a list of databases (namespaces) which name match the specify pattern and
- * available within the current catalog.
- *
- * @since 3.5.0
- */
- def listDatabases(pattern: String): Dataset[Database]
-
- /**
- * Returns a list of tables/views in the current database (namespace).
- * This includes all temporary views.
- *
- * @since 2.0.0
- */
- def listTables(): Dataset[Table]
-
- /**
- * Returns a list of tables/views in the specified database (namespace) (the name can be qualified
- * with catalog).
- * This includes all temporary views.
- *
- * @since 2.0.0
- */
- @throws[AnalysisException]("database does not exist")
- def listTables(dbName: String): Dataset[Table]
-
- /**
- * Returns a list of tables/views in the specified database (namespace)
- * which name match the specify pattern (the name can be qualified with catalog).
- * This includes all temporary views.
- *
- * @since 3.5.0
- */
- @throws[AnalysisException]("database does not exist")
- def listTables(dbName: String, pattern: String): Dataset[Table]
-
- /**
- * Returns a list of functions registered in the current database (namespace).
- * This includes all temporary functions.
- *
- * @since 2.0.0
- */
- def listFunctions(): Dataset[Function]
-
- /**
- * Returns a list of functions registered in the specified database (namespace) (the name can be
- * qualified with catalog).
- * This includes all built-in and temporary functions.
- *
- * @since 2.0.0
- */
- @throws[AnalysisException]("database does not exist")
- def listFunctions(dbName: String): Dataset[Function]
-
- /**
- * Returns a list of functions registered in the specified database (namespace)
- * which name match the specify pattern (the name can be qualified with catalog).
- * This includes all built-in and temporary functions.
- *
- * @since 3.5.0
- */
- @throws[AnalysisException]("database does not exist")
- def listFunctions(dbName: String, pattern: String): Dataset[Function]
- /**
- * Returns a list of columns for the given table/view or temporary view.
- *
- * @param tableName is either a qualified or unqualified name that designates a table/view. It
- * follows the same resolution rule with SQL: search for temp views first then
- * table/views in the current database (namespace).
- * @since 2.0.0
- */
- @throws[AnalysisException]("table does not exist")
- def listColumns(tableName: String): Dataset[Column]
+/** @inheritdoc */
+abstract class Catalog extends api.Catalog {
+ /** @inheritdoc */
+ override def listDatabases(): Dataset[Database]
- /**
- * Returns a list of columns for the given table/view in the specified database under the Hive
- * Metastore.
- *
- * To list columns for table/view in other catalogs, please use `listColumns(tableName)` with
- * qualified table/view name instead.
- *
- * @param dbName is an unqualified name that designates a database.
- * @param tableName is an unqualified name that designates a table/view.
- * @since 2.0.0
- */
- @throws[AnalysisException]("database or table does not exist")
- def listColumns(dbName: String, tableName: String): Dataset[Column]
+ /** @inheritdoc */
+ override def listDatabases(pattern: String): Dataset[Database]
- /**
- * Get the database (namespace) with the specified name (can be qualified with catalog). This
- * throws an AnalysisException when the database (namespace) cannot be found.
- *
- * @since 2.1.0
- */
- @throws[AnalysisException]("database does not exist")
- def getDatabase(dbName: String): Database
+ /** @inheritdoc */
+ override def listTables(): Dataset[Table]
- /**
- * Get the table or view with the specified name. This table can be a temporary view or a
- * table/view. This throws an AnalysisException when no Table can be found.
- *
- * @param tableName is either a qualified or unqualified name that designates a table/view. It
- * follows the same resolution rule with SQL: search for temp views first then
- * table/views in the current database (namespace).
- * @since 2.1.0
- */
- @throws[AnalysisException]("table does not exist")
- def getTable(tableName: String): Table
+ /** @inheritdoc */
+ override def listTables(dbName: String): Dataset[Table]
- /**
- * Get the table or view with the specified name in the specified database under the Hive
- * Metastore. This throws an AnalysisException when no Table can be found.
- *
- * To get table/view in other catalogs, please use `getTable(tableName)` with qualified table/view
- * name instead.
- *
- * @since 2.1.0
- */
- @throws[AnalysisException]("database or table does not exist")
- def getTable(dbName: String, tableName: String): Table
+ /** @inheritdoc */
+ override def listTables(dbName: String, pattern: String): Dataset[Table]
- /**
- * Get the function with the specified name. This function can be a temporary function or a
- * function. This throws an AnalysisException when the function cannot be found.
- *
- * @param functionName is either a qualified or unqualified name that designates a function. It
- * follows the same resolution rule with SQL: search for built-in/temp
- * functions first then functions in the current database (namespace).
- * @since 2.1.0
- */
- @throws[AnalysisException]("function does not exist")
- def getFunction(functionName: String): Function
+ /** @inheritdoc */
+ override def listFunctions(): Dataset[Function]
- /**
- * Get the function with the specified name in the specified database under the Hive Metastore.
- * This throws an AnalysisException when the function cannot be found.
- *
- * To get functions in other catalogs, please use `getFunction(functionName)` with qualified
- * function name instead.
- *
- * @param dbName is an unqualified name that designates a database.
- * @param functionName is an unqualified name that designates a function in the specified database
- * @since 2.1.0
- */
- @throws[AnalysisException]("database or function does not exist")
- def getFunction(dbName: String, functionName: String): Function
+ /** @inheritdoc */
+ override def listFunctions(dbName: String): Dataset[Function]
- /**
- * Check if the database (namespace) with the specified name exists (the name can be qualified
- * with catalog).
- *
- * @since 2.1.0
- */
- def databaseExists(dbName: String): Boolean
+ /** @inheritdoc */
+ override def listFunctions(dbName: String, pattern: String): Dataset[Function]
- /**
- * Check if the table or view with the specified name exists. This can either be a temporary
- * view or a table/view.
- *
- * @param tableName is either a qualified or unqualified name that designates a table/view. It
- * follows the same resolution rule with SQL: search for temp views first then
- * table/views in the current database (namespace).
- * @since 2.1.0
- */
- def tableExists(tableName: String): Boolean
+ /** @inheritdoc */
+ override def listColumns(tableName: String): Dataset[Column]
- /**
- * Check if the table or view with the specified name exists in the specified database under the
- * Hive Metastore.
- *
- * To check existence of table/view in other catalogs, please use `tableExists(tableName)` with
- * qualified table/view name instead.
- *
- * @param dbName is an unqualified name that designates a database.
- * @param tableName is an unqualified name that designates a table.
- * @since 2.1.0
- */
- def tableExists(dbName: String, tableName: String): Boolean
+ /** @inheritdoc */
+ override def listColumns(dbName: String, tableName: String): Dataset[Column]
- /**
- * Check if the function with the specified name exists. This can either be a temporary function
- * or a function.
- *
- * @param functionName is either a qualified or unqualified name that designates a function. It
- * follows the same resolution rule with SQL: search for built-in/temp
- * functions first then functions in the current database (namespace).
- * @since 2.1.0
- */
- def functionExists(functionName: String): Boolean
+ /** @inheritdoc */
+ override def createTable(tableName: String, path: String): DataFrame
- /**
- * Check if the function with the specified name exists in the specified database under the
- * Hive Metastore.
- *
- * To check existence of functions in other catalogs, please use `functionExists(functionName)`
- * with qualified function name instead.
- *
- * @param dbName is an unqualified name that designates a database.
- * @param functionName is an unqualified name that designates a function.
- * @since 2.1.0
- */
- def functionExists(dbName: String, functionName: String): Boolean
+ /** @inheritdoc */
+ override def createTable(tableName: String, path: String, source: String): DataFrame
- /**
- * Creates a table from the given path and returns the corresponding DataFrame.
- * It will use the default data source configured by spark.sql.sources.default.
- *
- * @param tableName is either a qualified or unqualified name that designates a table.
- * If no database identifier is provided, it refers to a table in
- * the current database.
- * @since 2.0.0
- */
- @deprecated("use createTable instead.", "2.2.0")
- def createExternalTable(tableName: String, path: String): DataFrame = {
- createTable(tableName, path)
- }
-
- /**
- * Creates a table from the given path and returns the corresponding DataFrame.
- * It will use the default data source configured by spark.sql.sources.default.
- *
- * @param tableName is either a qualified or unqualified name that designates a table.
- * If no database identifier is provided, it refers to a table in
- * the current database.
- * @since 2.2.0
- */
- def createTable(tableName: String, path: String): DataFrame
-
- /**
- * Creates a table from the given path based on a data source and returns the corresponding
- * DataFrame.
- *
- * @param tableName is either a qualified or unqualified name that designates a table.
- * If no database identifier is provided, it refers to a table in
- * the current database.
- * @since 2.0.0
- */
- @deprecated("use createTable instead.", "2.2.0")
- def createExternalTable(tableName: String, path: String, source: String): DataFrame = {
- createTable(tableName, path, source)
- }
-
- /**
- * Creates a table from the given path based on a data source and returns the corresponding
- * DataFrame.
- *
- * @param tableName is either a qualified or unqualified name that designates a table.
- * If no database identifier is provided, it refers to a table in
- * the current database.
- * @since 2.2.0
- */
- def createTable(tableName: String, path: String, source: String): DataFrame
-
- /**
- * Creates a table from the given path based on a data source and a set of options.
- * Then, returns the corresponding DataFrame.
- *
- * @param tableName is either a qualified or unqualified name that designates a table.
- * If no database identifier is provided, it refers to a table in
- * the current database.
- * @since 2.0.0
- */
- @deprecated("use createTable instead.", "2.2.0")
- def createExternalTable(
+ /** @inheritdoc */
+ override def createTable(
tableName: String,
source: String,
- options: java.util.Map[String, String]): DataFrame = {
- createTable(tableName, source, options)
- }
+ options: Map[String, String]): DataFrame
- /**
- * Creates a table based on the dataset in a data source and a set of options.
- * Then, returns the corresponding DataFrame.
- *
- * @param tableName is either a qualified or unqualified name that designates a table.
- * If no database identifier is provided, it refers to a table in
- * the current database.
- * @since 2.2.0
- */
- def createTable(
+ /** @inheritdoc */
+ override def createTable(
tableName: String,
source: String,
- options: java.util.Map[String, String]): DataFrame = {
- createTable(tableName, source, options.asScala.toMap)
- }
+ description: String,
+ options: Map[String, String]): DataFrame
- /**
- * (Scala-specific)
- * Creates a table from the given path based on a data source and a set of options.
- * Then, returns the corresponding DataFrame.
- *
- * @param tableName is either a qualified or unqualified name that designates a table.
- * If no database identifier is provided, it refers to a table in
- * the current database.
- * @since 2.0.0
- */
- @deprecated("use createTable instead.", "2.2.0")
- def createExternalTable(
+ /** @inheritdoc */
+ override def createTable(
tableName: String,
source: String,
- options: Map[String, String]): DataFrame = {
- createTable(tableName, source, options)
- }
+ schema: StructType,
+ options: Map[String, String]): DataFrame
- /**
- * (Scala-specific)
- * Creates a table based on the dataset in a data source and a set of options.
- * Then, returns the corresponding DataFrame.
- *
- * @param tableName is either a qualified or unqualified name that designates a table.
- * If no database identifier is provided, it refers to a table in
- * the current database.
- * @since 2.2.0
- */
- def createTable(
+ /** @inheritdoc */
+ override def createTable(
tableName: String,
source: String,
+ schema: StructType,
+ description: String,
options: Map[String, String]): DataFrame
- /**
- * Create a table from the given path based on a data source, a schema and a set of options.
- * Then, returns the corresponding DataFrame.
- *
- * @param tableName is either a qualified or unqualified name that designates a table.
- * If no database identifier is provided, it refers to a table in
- * the current database.
- * @since 2.0.0
- */
- @deprecated("use createTable instead.", "2.2.0")
- def createExternalTable(
+ /** @inheritdoc */
+ override def listCatalogs(): Dataset[CatalogMetadata]
+
+ /** @inheritdoc */
+ override def listCatalogs(pattern: String): Dataset[CatalogMetadata]
+
+ /** @inheritdoc */
+ override def createExternalTable(tableName: String, path: String): DataFrame =
+ super.createExternalTable(tableName, path)
+
+ /** @inheritdoc */
+ override def createExternalTable(tableName: String, path: String, source: String): DataFrame =
+ super.createExternalTable(tableName, path, source)
+
+ /** @inheritdoc */
+ override def createExternalTable(
tableName: String,
source: String,
- schema: StructType,
- options: java.util.Map[String, String]): DataFrame = {
- createTable(tableName, source, schema, options)
- }
+ options: util.Map[String, String]): DataFrame =
+ super.createExternalTable(tableName, source, options)
- /**
- * Creates a table based on the dataset in a data source and a set of options.
- * Then, returns the corresponding DataFrame.
- *
- * @param tableName is either a qualified or unqualified name that designates a table.
- * If no database identifier is provided, it refers to a table in
- * the current database.
- * @since 3.1.0
- */
- def createTable(
+ /** @inheritdoc */
+ override def createTable(
tableName: String,
source: String,
- description: String,
- options: java.util.Map[String, String]): DataFrame = {
- createTable(
- tableName,
- source = source,
- description = description,
- options = options.asScala.toMap
- )
- }
+ options: util.Map[String, String]): DataFrame =
+ super.createTable(tableName, source, options)
- /**
- * (Scala-specific)
- * Creates a table based on the dataset in a data source and a set of options.
- * Then, returns the corresponding DataFrame.
- *
- * @param tableName is either a qualified or unqualified name that designates a table.
- * If no database identifier is provided, it refers to a table in
- * the current database.
- * @since 3.1.0
- */
- def createTable(
+ /** @inheritdoc */
+ override def createExternalTable(
tableName: String,
source: String,
- description: String,
- options: Map[String, String]): DataFrame
+ options: Map[String, String]): DataFrame =
+ super.createExternalTable(tableName, source, options)
- /**
- * Create a table based on the dataset in a data source, a schema and a set of options.
- * Then, returns the corresponding DataFrame.
- *
- * @param tableName is either a qualified or unqualified name that designates a table.
- * If no database identifier is provided, it refers to a table in
- * the current database.
- * @since 2.2.0
- */
- def createTable(
+ /** @inheritdoc */
+ override def createExternalTable(
tableName: String,
source: String,
schema: StructType,
- options: java.util.Map[String, String]): DataFrame = {
- createTable(tableName, source, schema, options.asScala.toMap)
- }
+ options: util.Map[String, String]): DataFrame =
+ super.createExternalTable(tableName, source, schema, options)
- /**
- * (Scala-specific)
- * Create a table from the given path based on a data source, a schema and a set of options.
- * Then, returns the corresponding DataFrame.
- *
- * @param tableName is either a qualified or unqualified name that designates a table.
- * If no database identifier is provided, it refers to a table in
- * the current database.
- * @since 2.0.0
- */
- @deprecated("use createTable instead.", "2.2.0")
- def createExternalTable(
+ /** @inheritdoc */
+ override def createTable(
tableName: String,
source: String,
- schema: StructType,
- options: Map[String, String]): DataFrame = {
- createTable(tableName, source, schema, options)
- }
+ description: String,
+ options: util.Map[String, String]): DataFrame =
+ super.createTable(tableName, source, description, options)
- /**
- * (Scala-specific)
- * Create a table based on the dataset in a data source, a schema and a set of options.
- * Then, returns the corresponding DataFrame.
- *
- * @param tableName is either a qualified or unqualified name that designates a table.
- * If no database identifier is provided, it refers to a table in
- * the current database.
- * @since 2.2.0
- */
- def createTable(
+ /** @inheritdoc */
+ override def createTable(
tableName: String,
source: String,
schema: StructType,
- options: Map[String, String]): DataFrame
+ options: util.Map[String, String]): DataFrame =
+ super.createTable(tableName, source, schema, options)
- /**
- * Create a table based on the dataset in a data source, a schema and a set of options.
- * Then, returns the corresponding DataFrame.
- *
- * @param tableName is either a qualified or unqualified name that designates a table.
- * If no database identifier is provided, it refers to a table in
- * the current database.
- * @since 3.1.0
- */
- def createTable(
+ /** @inheritdoc */
+ override def createExternalTable(
tableName: String,
source: String,
schema: StructType,
- description: String,
- options: java.util.Map[String, String]): DataFrame = {
- createTable(
- tableName,
- source = source,
- schema = schema,
- description = description,
- options = options.asScala.toMap
- )
- }
+ options: Map[String, String]): DataFrame =
+ super.createExternalTable(tableName, source, schema, options)
- /**
- * (Scala-specific)
- * Create a table based on the dataset in a data source, a schema and a set of options.
- * Then, returns the corresponding DataFrame.
- *
- * @param tableName is either a qualified or unqualified name that designates a table.
- * If no database identifier is provided, it refers to a table in
- * the current database.
- * @since 3.1.0
- */
- def createTable(
+ /** @inheritdoc */
+ override def createTable(
tableName: String,
source: String,
schema: StructType,
description: String,
- options: Map[String, String]): DataFrame
-
- /**
- * Drops the local temporary view with the given view name in the catalog.
- * If the view has been cached before, then it will also be uncached.
- *
- * Local temporary view is session-scoped. Its lifetime is the lifetime of the session that
- * created it, i.e. it will be automatically dropped when the session terminates. It's not
- * tied to any databases, i.e. we can't use `db1.view1` to reference a local temporary view.
- *
- * Note that, the return type of this method was Unit in Spark 2.0, but changed to Boolean
- * in Spark 2.1.
- *
- * @param viewName the name of the temporary view to be dropped.
- * @return true if the view is dropped successfully, false otherwise.
- * @since 2.0.0
- */
- def dropTempView(viewName: String): Boolean
-
- /**
- * Drops the global temporary view with the given view name in the catalog.
- * If the view has been cached before, then it will also be uncached.
- *
- * Global temporary view is cross-session. Its lifetime is the lifetime of the Spark application,
- * i.e. it will be automatically dropped when the application terminates. It's tied to a system
- * preserved database `global_temp`, and we must use the qualified name to refer a global temp
- * view, e.g. `SELECT * FROM global_temp.view1`.
- *
- * @param viewName the unqualified name of the temporary view to be dropped.
- * @return true if the view is dropped successfully, false otherwise.
- * @since 2.1.0
- */
- def dropGlobalTempView(viewName: String): Boolean
-
- /**
- * Recovers all the partitions in the directory of a table and update the catalog.
- * Only works with a partitioned table, and not a view.
- *
- * @param tableName is either a qualified or unqualified name that designates a table.
- * If no database identifier is provided, it refers to a table in the
- * current database.
- * @since 2.1.1
- */
- def recoverPartitions(tableName: String): Unit
-
- /**
- * Returns true if the table is currently cached in-memory.
- *
- * @param tableName is either a qualified or unqualified name that designates a table/view.
- * If no database identifier is provided, it refers to a temporary view or
- * a table/view in the current database.
- * @since 2.0.0
- */
- def isCached(tableName: String): Boolean
-
- /**
- * Caches the specified table in-memory.
- *
- * @param tableName is either a qualified or unqualified name that designates a table/view.
- * If no database identifier is provided, it refers to a temporary view or
- * a table/view in the current database.
- * @since 2.0.0
- */
- def cacheTable(tableName: String): Unit
-
- /**
- * Caches the specified table with the given storage level.
- *
- * @param tableName is either a qualified or unqualified name that designates a table/view.
- * If no database identifier is provided, it refers to a temporary view or
- * a table/view in the current database.
- * @param storageLevel storage level to cache table.
- * @since 2.3.0
- */
- def cacheTable(tableName: String, storageLevel: StorageLevel): Unit
-
-
- /**
- * Removes the specified table from the in-memory cache.
- *
- * @param tableName is either a qualified or unqualified name that designates a table/view.
- * If no database identifier is provided, it refers to a temporary view or
- * a table/view in the current database.
- * @since 2.0.0
- */
- def uncacheTable(tableName: String): Unit
-
- /**
- * Removes all cached tables from the in-memory cache.
- *
- * @since 2.0.0
- */
- def clearCache(): Unit
-
- /**
- * Invalidates and refreshes all the cached data and metadata of the given table. For performance
- * reasons, Spark SQL or the external data source library it uses might cache certain metadata
- * about a table, such as the location of blocks. When those change outside of Spark SQL, users
- * should call this function to invalidate the cache.
- *
- * If this table is cached as an InMemoryRelation, drop the original cached version and make the
- * new version cached lazily.
- *
- * @param tableName is either a qualified or unqualified name that designates a table/view.
- * If no database identifier is provided, it refers to a temporary view or
- * a table/view in the current database.
- * @since 2.0.0
- */
- def refreshTable(tableName: String): Unit
-
- /**
- * Invalidates and refreshes all the cached data (and the associated metadata) for any `Dataset`
- * that contains the given data source path. Path matching is by checking for sub-directories,
- * i.e. "/" would invalidate everything that is cached and "/test/parent" would invalidate
- * everything that is a subdirectory of "/test/parent".
- *
- * @since 2.0.0
- */
- def refreshByPath(path: String): Unit
-
- /**
- * Returns the current catalog in this session.
- *
- * @since 3.4.0
- */
- def currentCatalog(): String
-
- /**
- * Sets the current catalog in this session.
- *
- * @since 3.4.0
- */
- def setCurrentCatalog(catalogName: String): Unit
-
- /**
- * Returns a list of catalogs available in this session.
- *
- * @since 3.4.0
- */
- def listCatalogs(): Dataset[CatalogMetadata]
-
- /**
- * Returns a list of catalogs which name match the specify pattern and available in this session.
- *
- * @since 3.5.0
- */
- def listCatalogs(pattern: String): Dataset[CatalogMetadata]
+ options: util.Map[String, String]): DataFrame =
+ super.createTable(tableName, source, schema, description, options)
}
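
For context, a minimal sketch of calling the Scala `createTable` overload kept above; the table name, path, and the `spark` SparkSession value are illustrative assumptions, not part of this change:

  import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

  val schema = StructType(Seq(StructField("id", LongType), StructField("name", StringType)))
  // createTable(tableName, source, schema, description, options) returns the new table as a DataFrame.
  spark.catalog.createTable(
    "demo_tbl",                                   // hypothetical table name
    source = "parquet",
    schema = schema,
    description = "example table",                // hypothetical description
    options = Map("path" -> "/tmp/demo_tbl"))     // hypothetical location
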
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala
deleted file mode 100644
index 31aee5a43ef47..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala
+++ /dev/null
@@ -1,224 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalog
-
-import javax.annotation.Nullable
-
-import org.apache.spark.annotation.Stable
-import org.apache.spark.sql.catalyst.DefinedByConstructorParams
-
-
-// Note: all classes here are expected to be wrapped in Datasets and so must extend
-// DefinedByConstructorParams for the catalog to be able to create encoders for them.
-
-/**
- * A catalog in Spark, as returned by the `listCatalogs` method defined in [[Catalog]].
- *
- * @param name name of the catalog
- * @param description description of the catalog
- * @since 3.4.0
- */
-class CatalogMetadata(
- val name: String,
- @Nullable val description: String)
- extends DefinedByConstructorParams {
-
- override def toString: String =
- s"Catalog[name='$name', ${Option(description).map(d => s"description='$d'").getOrElse("")}]"
-}
-
-/**
- * A database in Spark, as returned by the `listDatabases` method defined in [[Catalog]].
- *
- * @param name name of the database.
- * @param catalog name of the catalog that the table belongs to.
- * @param description description of the database.
- * @param locationUri path (in the form of a uri) to data files.
- * @since 2.0.0
- */
-@Stable
-class Database(
- val name: String,
- @Nullable val catalog: String,
- @Nullable val description: String,
- val locationUri: String)
- extends DefinedByConstructorParams {
-
- def this(name: String, description: String, locationUri: String) = {
- this(name, null, description, locationUri)
- }
-
- override def toString: String = {
- "Database[" +
- s"name='$name', " +
- Option(catalog).map { c => s"catalog='$c', " }.getOrElse("") +
- Option(description).map { d => s"description='$d', " }.getOrElse("") +
- s"path='$locationUri']"
- }
-
-}
-
-
-/**
- * A table in Spark, as returned by the `listTables` method in [[Catalog]].
- *
- * @param name name of the table.
- * @param catalog name of the catalog that the table belongs to.
- * @param namespace the namespace that the table belongs to.
- * @param description description of the table.
- * @param tableType type of the table (e.g. view, table).
- * @param isTemporary whether the table is a temporary table.
- * @since 2.0.0
- */
-@Stable
-class Table(
- val name: String,
- @Nullable val catalog: String,
- @Nullable val namespace: Array[String],
- @Nullable val description: String,
- val tableType: String,
- val isTemporary: Boolean)
- extends DefinedByConstructorParams {
-
- if (namespace != null) {
- assert(namespace.forall(_ != null))
- }
-
- def this(
- name: String,
- database: String,
- description: String,
- tableType: String,
- isTemporary: Boolean) = {
- this(name, null, if (database != null) Array(database) else null,
- description, tableType, isTemporary)
- }
-
- def database: String = {
- if (namespace != null && namespace.length == 1) namespace(0) else null
- }
-
- override def toString: String = {
- "Table[" +
- s"name='$name', " +
- Option(catalog).map { d => s"catalog='$d', " }.getOrElse("") +
- Option(database).map { d => s"database='$d', " }.getOrElse("") +
- Option(description).map { d => s"description='$d', " }.getOrElse("") +
- s"tableType='$tableType', " +
- s"isTemporary='$isTemporary']"
- }
-
-}
-
-
-/**
- * A column in Spark, as returned by `listColumns` method in [[Catalog]].
- *
- * @param name name of the column.
- * @param description description of the column.
- * @param dataType data type of the column.
- * @param nullable whether the column is nullable.
- * @param isPartition whether the column is a partition column.
- * @param isBucket whether the column is a bucket column.
- * @param isCluster whether the column is a clustering column.
- * @since 2.0.0
- */
-@Stable
-class Column(
- val name: String,
- @Nullable val description: String,
- val dataType: String,
- val nullable: Boolean,
- val isPartition: Boolean,
- val isBucket: Boolean,
- val isCluster: Boolean)
- extends DefinedByConstructorParams {
-
- def this(
- name: String,
- description: String,
- dataType: String,
- nullable: Boolean,
- isPartition: Boolean,
- isBucket: Boolean) = {
- this(name, description, dataType, nullable, isPartition, isBucket, isCluster = false)
- }
-
- override def toString: String = {
- "Column[" +
- s"name='$name', " +
- Option(description).map { d => s"description='$d', " }.getOrElse("") +
- s"dataType='$dataType', " +
- s"nullable='$nullable', " +
- s"isPartition='$isPartition', " +
- s"isBucket='$isBucket', " +
- s"isCluster='$isCluster']"
- }
-
-}
-
-
-/**
- * A user-defined function in Spark, as returned by `listFunctions` method in [[Catalog]].
- *
- * @param name name of the function.
- * @param catalog name of the catalog that the table belongs to.
- * @param namespace the namespace that the table belongs to.
- * @param description description of the function; description can be null.
- * @param className the fully qualified class name of the function.
- * @param isTemporary whether the function is a temporary function or not.
- * @since 2.0.0
- */
-@Stable
-class Function(
- val name: String,
- @Nullable val catalog: String,
- @Nullable val namespace: Array[String],
- @Nullable val description: String,
- val className: String,
- val isTemporary: Boolean)
- extends DefinedByConstructorParams {
-
- if (namespace != null) {
- assert(namespace.forall(_ != null))
- }
-
- def this(
- name: String,
- database: String,
- description: String,
- className: String,
- isTemporary: Boolean) = {
- this(name, null, if (database != null) Array(database) else null,
- description, className, isTemporary)
- }
-
- def database: String = {
- if (namespace != null && namespace.length == 1) namespace(0) else null
- }
-
- override def toString: String = {
- "Function[" +
- s"name='$name', " +
- Option(database).map { d => s"database='$d', " }.getOrElse("") +
- Option(description).map { d => s"description='$d', " }.getOrElse("") +
- s"className='$className', " +
- s"isTemporary='$isTemporary']"
- }
-
-}
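
These metadata classes back the Datasets returned by the Catalog API (the file is removed here, with the definitions presumably relocated rather than dropped). A sketch of the unchanged user-facing surface, assuming an existing `spark` session and a hypothetical table name:

  spark.catalog.listDatabases().show()          // Dataset[Database]
  spark.catalog.listTables().show()             // Dataset[Table]
  spark.catalog.listColumns("demo_tbl").show()  // Dataset[Column]
  spark.catalog.listCatalogs().show()           // Dataset[CatalogMetadata]
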
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/InvokeProcedures.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/InvokeProcedures.scala
new file mode 100644
index 0000000000000..c7320d350a7ff
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/InvokeProcedures.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import scala.jdk.CollectionConverters.IteratorHasAsScala
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow}
+import org.apache.spark.sql.catalyst.plans.logical.{Call, LocalRelation, LogicalPlan, MultiResult}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
+import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure
+import org.apache.spark.sql.connector.read.{LocalScan, Scan}
+import org.apache.spark.util.ArrayImplicits._
+
+class InvokeProcedures(session: SparkSession) extends Rule[LogicalPlan] {
+
+ override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ case c: Call if c.resolved && c.bound && c.execute && c.checkArgTypes().isSuccess =>
+ session.sessionState.optimizer.execute(c) match {
+ case Call(ResolvedProcedure(_, _, procedure: BoundProcedure), args, _) =>
+ invoke(procedure, args)
+ case _ =>
+ throw SparkException.internalError("Unexpected plan for optimized CALL statement")
+ }
+ }
+
+ private def invoke(procedure: BoundProcedure, args: Seq[Expression]): LogicalPlan = {
+ val input = toInternalRow(args)
+ val scanIterator = procedure.call(input)
+ val relations = scanIterator.asScala.map(toRelation).toSeq
+ relations match {
+ case Nil => LocalRelation(Nil)
+ case Seq(relation) => relation
+ case _ => MultiResult(relations)
+ }
+ }
+
+ private def toRelation(scan: Scan): LogicalPlan = scan match {
+ case s: LocalScan =>
+ val attrs = DataTypeUtils.toAttributes(s.readSchema)
+ val data = s.rows.toImmutableArraySeq
+ LocalRelation(attrs, data)
+ case _ =>
+ throw SparkException.internalError(
+ s"Only local scans are temporarily supported as procedure output: ${scan.getClass.getName}")
+ }
+
+ private def toInternalRow(args: Seq[Expression]): InternalRow = {
+ require(args.forall(_.foldable), "args must be foldable")
+ val values = args.map(_.eval()).toArray
+ new GenericInternalRow(values)
+ }
+}
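
A minimal sketch (plain Scala, not Spark classes) of the result-shaping convention used by `invoke` above: zero scans become an empty relation, a single scan is returned as-is, and several become a multi-result node:

  def shapeResults[A](relations: Seq[A])(empty: => A, multi: Seq[A] => A): A =
    relations match {
      case Nil         => empty        // LocalRelation(Nil) in the rule
      case Seq(single) => single       // the single relation itself
      case many        => multi(many)  // MultiResult(relations)
    }
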
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
index d569f1ed484cc..02ad2e79a5645 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, toPrettySQL, ResolveDefaultColumns => DefaultCols}
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._
-import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, CatalogV2Util, LookupCatalog, SupportsNamespaces, V1Table}
+import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, CatalogV2Util, DelegatingCatalogExtension, LookupCatalog, SupportsNamespaces, V1Table}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.command._
@@ -284,10 +284,20 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager)
case AnalyzeColumn(ResolvedV1TableOrViewIdentifier(ident), columnNames, allColumns) =>
AnalyzeColumnCommand(ident, columnNames, allColumns)
- case RepairTable(ResolvedV1TableIdentifier(ident), addPartitions, dropPartitions) =>
+ // V2 catalog doesn't support REPAIR TABLE yet, we must use v1 command here.
+ case RepairTable(
+ ResolvedV1TableIdentifierInSessionCatalog(ident),
+ addPartitions,
+ dropPartitions) =>
RepairTableCommand(ident, addPartitions, dropPartitions)
- case LoadData(ResolvedV1TableIdentifier(ident), path, isLocal, isOverwrite, partition) =>
+ // V2 catalog doesn't support LOAD DATA yet, we must use v1 command here.
+ case LoadData(
+ ResolvedV1TableIdentifierInSessionCatalog(ident),
+ path,
+ isLocal,
+ isOverwrite,
+ partition) =>
LoadDataCommand(
ident,
path,
@@ -336,7 +346,8 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager)
}
ShowColumnsCommand(db, v1TableName, output)
- case RecoverPartitions(ResolvedV1TableIdentifier(ident)) =>
+ // V2 catalog doesn't support RECOVER PARTITIONS yet, we must use v1 command here.
+ case RecoverPartitions(ResolvedV1TableIdentifierInSessionCatalog(ident)) =>
RepairTableCommand(
ident,
enableAddPartitions = true,
@@ -364,8 +375,9 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager)
purge,
retainData = false)
+ // V2 catalog doesn't support setting serde properties yet, we must use v1 command here.
case SetTableSerDeProperties(
- ResolvedV1TableIdentifier(ident),
+ ResolvedV1TableIdentifierInSessionCatalog(ident),
serdeClassName,
serdeProperties,
partitionSpec) =>
@@ -380,10 +392,10 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager)
// V2 catalog doesn't support setting partition location yet, we must use v1 command here.
case SetTableLocation(
- ResolvedTable(catalog, _, t: V1Table, _),
+ ResolvedV1TableIdentifierInSessionCatalog(ident),
Some(partitionSpec),
- location) if isSessionCatalog(catalog) =>
- AlterTableSetLocationCommand(t.v1Table.identifier, Some(partitionSpec), location)
+ location) =>
+ AlterTableSetLocationCommand(ident, Some(partitionSpec), location)
case AlterViewAs(ResolvedViewIdentifier(ident), originalText, query) =>
AlterViewAsCommand(ident, originalText, query)
@@ -600,6 +612,14 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager)
}
}
+ object ResolvedV1TableIdentifierInSessionCatalog {
+ def unapply(resolved: LogicalPlan): Option[TableIdentifier] = resolved match {
+ case ResolvedTable(catalog, _, t: V1Table, _) if isSessionCatalog(catalog) =>
+ Some(t.catalogTable.identifier)
+ case _ => None
+ }
+ }
+
object ResolvedV1TableOrViewIdentifier {
def unapply(resolved: LogicalPlan): Option[TableIdentifier] = resolved match {
case ResolvedV1TableIdentifier(ident) => Some(ident)
@@ -684,7 +704,8 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager)
}
private def supportsV1Command(catalog: CatalogPlugin): Boolean = {
- isSessionCatalog(catalog) &&
- SQLConf.get.getConf(SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION).isEmpty
+ isSessionCatalog(catalog) && (
+ SQLConf.get.getConf(SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION).isEmpty ||
+ catalog.isInstanceOf[DelegatingCatalogExtension])
}
}
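
The new `ResolvedV1TableIdentifierInSessionCatalog` follows the usual guarded-extractor pattern. A stripped-down analogue with hypothetical types, just to show the shape:

  case class Resolved(catalogName: String, identifier: String)

  object SessionCatalogIdent {
    // Match only when the extra predicate holds, so call sites stay a flat pattern match.
    def unapply(r: Resolved): Option[String] =
      if (r.catalogName == "spark_catalog") Some(r.identifier) else None
  }

  // Resolved("spark_catalog", "db.t") match { case SessionCatalogIdent(id) => id }
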
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTranspose.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTranspose.scala
new file mode 100644
index 0000000000000..d71237ca15ec3
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTranspose.scala
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeReference, Cast, IsNotNull, Literal, SortOrder}
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, Limit, LogicalPlan, Project, Sort, Transpose}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{AtomicType, DataType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
+
+
+/**
+ * Rule that resolves and transforms an `UnresolvedTranspose` logical plan into a `Transpose`
+ * logical plan, which effectively transposes a DataFrame by turning rows into columns based
+ * on a specified index column.
+ *
+ * The high-level logic for the transpose operation is as follows:
+ * - If the index column is not provided, the first column of the DataFrame is used as the
+ * default index column.
+ * - The index column is cast to `StringType` to ensure consistent column naming.
+ * - Non-index columns are cast to a common data type, determined by finding the least
+ * common type that can accommodate all non-index columns.
+ * - The data is sorted by the index column, and rows with `null` index values are excluded
+ * from the transpose operation.
+ * - The transposed DataFrame is constructed by turning the original rows into columns, with
+ * the index column values becoming the new column names and the non-index column values
+ * populating the transposed data.
+ */
+class ResolveTranspose(sparkSession: SparkSession) extends Rule[LogicalPlan] {
+
+ private def leastCommonType(dataTypes: Seq[DataType]): DataType = {
+ if (dataTypes.isEmpty) {
+ StringType
+ } else {
+ dataTypes.reduce { (dt1, dt2) =>
+ AnsiTypeCoercion.findTightestCommonType(dt1, dt2).getOrElse {
+ throw new AnalysisException(
+ errorClass = "TRANSPOSE_NO_LEAST_COMMON_TYPE",
+ messageParameters = Map(
+ "dt1" -> dt1.sql,
+ "dt2" -> dt2.sql)
+ )
+ }
+ }
+ }
+ }
+
+ private def transposeMatrix(
+ fullCollectedRows: Array[InternalRow],
+ nonIndexColumnNames: Seq[String],
+ nonIndexColumnDataTypes: Seq[DataType]): Array[Array[Any]] = {
+ val numTransposedRows = fullCollectedRows.head.numFields - 1
+ val numTransposedCols = fullCollectedRows.length + 1
+ val finalMatrix = Array.ofDim[Any](numTransposedRows, numTransposedCols)
+
+ // Example of the original DataFrame:
+ // +---+-----+-----+
+ // | id|col1 |col2 |
+ // +---+-----+-----+
+ // | 1| 10 | 20 |
+ // | 2| 30 | 40 |
+ // +---+-----+-----+
+ //
+ // After transposition, the finalMatrix will look like:
+ // [
+ // ["col1", 10, 30], // Transposed row for col1
+ // ["col2", 20, 40] // Transposed row for col2
+ // ]
+
+ for (i <- 0 until numTransposedRows) {
+ // Insert non-index column name as the first element in each transposed row
+ finalMatrix(i)(0) = UTF8String.fromString(nonIndexColumnNames(i))
+
+ for (j <- 1 until numTransposedCols) {
+ // Insert the transposed data
+
+ // Example: If j = 2, then row = fullCollectedRows(1)
+ // This corresponds to the second row of the original DataFrame: InternalRow(2, 30, 40)
+ val row = fullCollectedRows(j - 1)
+
+ // Example: If i = 0 (for "col1"), and j = 2,
+ // then finalMatrix(0)(2) corresponds to row.get(1, nonIndexColumnDataTypes(0)),
+ // which accesses the value 30 from InternalRow(2, 30, 40)
+ finalMatrix(i)(j) = row.get(i + 1, nonIndexColumnDataTypes(i))
+ }
+ }
+ finalMatrix
+ }
+
+ override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
+ _.containsPattern(TreePattern.UNRESOLVED_TRANSPOSE)) {
+ case t @ UnresolvedTranspose(indices, child) if child.resolved && indices.forall(_.resolved) =>
+ assert(indices.length == 0 || indices.length == 1,
+ "The number of index columns should be either 0 or 1.")
+
+ // Handle empty frame with no column headers
+ if (child.output.isEmpty) {
+ return Transpose(Seq.empty)
+ }
+
+ // Use the first column as index column if not provided
+ val inferredIndexColumn = if (indices.isEmpty) {
+ child.output.head
+ } else {
+ indices.head
+ }
+
+ // Cast the index column to StringType
+ val indexColumnAsString = inferredIndexColumn match {
+ case attr: Attribute if attr.dataType.isInstanceOf[AtomicType] =>
+ Alias(Cast(attr, StringType), attr.name)()
+ case attr: Attribute =>
+ throw new AnalysisException(
+ errorClass = "TRANSPOSE_INVALID_INDEX_COLUMN",
+ messageParameters = Map(
+ "reason" -> s"Index column must be of atomic type, but found: ${attr.dataType}")
+ )
+ case _ =>
+ throw new AnalysisException(
+ errorClass = "TRANSPOSE_INVALID_INDEX_COLUMN",
+ messageParameters = Map(
+ "reason" -> s"Index column must be an atomic attribute")
+ )
+ }
+
+ // Cast non-index columns to the least common type
+ val nonIndexColumns = child.output.filterNot(
+ _.exprId == inferredIndexColumn.asInstanceOf[Attribute].exprId)
+ val nonIndexTypes = nonIndexColumns.map(_.dataType)
+ val commonType = leastCommonType(nonIndexTypes)
+ val nonIndexColumnsAsLCT = nonIndexColumns.map { attr =>
+ Alias(Cast(attr, commonType), attr.name)()
+ }
+
+ // Exclude nulls and sort index column values, and collect the casted frame
+ val allCastCols = indexColumnAsString +: nonIndexColumnsAsLCT
+ val nonNullChild = Filter(IsNotNull(inferredIndexColumn), child)
+ val sortedChild = Sort(
+ Seq(SortOrder(inferredIndexColumn, Ascending)),
+ global = true,
+ nonNullChild
+ )
+ val projectAllCastCols = Project(allCastCols, sortedChild)
+ val maxValues = sparkSession.sessionState.conf.dataFrameTransposeMaxValues
+ val limit = Literal(maxValues + 1)
+ val limitedProject = Limit(limit, projectAllCastCols)
+ val queryExecution = sparkSession.sessionState.executePlan(limitedProject)
+ val fullCollectedRows = queryExecution.executedPlan.executeCollect()
+
+ if (fullCollectedRows.isEmpty) {
+ // Return a DataFrame with a single column "key" containing non-index column names
+ val keyAttr = AttributeReference("key", StringType, nullable = false)()
+ val keyValues = nonIndexColumns.map(col => UTF8String.fromString(col.name))
+ val keyRows = keyValues.map(value => InternalRow(value))
+
+ Transpose(Seq(keyAttr), keyRows)
+ } else {
+ if (fullCollectedRows.length > maxValues) {
+ throw new AnalysisException(
+ errorClass = "TRANSPOSE_EXCEED_ROW_LIMIT",
+ messageParameters = Map(
+ "maxValues" -> maxValues.toString,
+ "config" -> SQLConf.DATAFRAME_TRANSPOSE_MAX_VALUES.key))
+ }
+
+ // Transpose the matrix
+ val nonIndexColumnNames = nonIndexColumns.map(_.name)
+ val nonIndexColumnDataTypes = projectAllCastCols.output.tail.map(attr => attr.dataType)
+ val transposedMatrix = transposeMatrix(
+ fullCollectedRows, nonIndexColumnNames, nonIndexColumnDataTypes)
+ val transposedInternalRows = transposedMatrix.map { row =>
+ InternalRow.fromSeq(row.toIndexedSeq)
+ }
+
+ // Construct output attributes
+ val keyAttr = AttributeReference("key", StringType, nullable = false)()
+ val transposedColumnNames = fullCollectedRows.map { row => row.getString(0) }
+ val valueAttrs = transposedColumnNames.map { value =>
+ AttributeReference(
+ value,
+ commonType
+ )()
+ }
+
+ val transposeOutput = (keyAttr +: valueAttrs).toIndexedSeq
+ val transposeData = transposedInternalRows.toIndexedSeq
+ Transpose(transposeOutput, transposeData)
+ }
+ }
+}
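
The row/column flip performed by `transposeMatrix`, reduced to plain Scala collections (illustrative only; the real code works on `InternalRow`s and prepends the non-index column names, as shown in its comments):

  val header = Seq("id", "col1", "col2")
  val rows   = Seq(Seq("1", "10", "20"), Seq("2", "30", "40"))

  // Each non-index column becomes a transposed row: its name, then its value in every original row.
  val transposed = header.tail.zipWithIndex.map { case (name, i) =>
    name +: rows.map(r => r(i + 1))
  }
  // transposed == Seq(Seq("col1", "10", "30"), Seq("col2", "20", "40"))
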
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/ClassicConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/ClassicConversions.scala
new file mode 100644
index 0000000000000..af91b57a6848b
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/ClassicConversions.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.classic
+
+import scala.language.implicitConversions
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql._
+
+/**
+ * Conversions from sql interfaces to the Classic specific implementation.
+ *
+ * This class is mainly used by the implementation, but is also meant to be used by extension
+ * developers.
+ *
+ * We provide both a trait and an object. The trait is useful in situations where an extension
+ * developer needs to use these conversions in a project covering multiple Spark versions. They can
+ * create a shim for these conversions, the Spark 4+ version of the shim implements this trait, and
+ * shims for older versions do not.
+ */
+@DeveloperApi
+trait ClassicConversions {
+ implicit def castToImpl(session: api.SparkSession): SparkSession =
+ session.asInstanceOf[SparkSession]
+
+ implicit def castToImpl[T](ds: api.Dataset[T]): Dataset[T] =
+ ds.asInstanceOf[Dataset[T]]
+
+ implicit def castToImpl(rgds: api.RelationalGroupedDataset): RelationalGroupedDataset =
+ rgds.asInstanceOf[RelationalGroupedDataset]
+
+ implicit def castToImpl[K, V](kvds: api.KeyValueGroupedDataset[K, V])
+ : KeyValueGroupedDataset[K, V] = kvds.asInstanceOf[KeyValueGroupedDataset[K, V]]
+}
+
+object ClassicConversions extends ClassicConversions
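
One way an extension might consume this, as the scaladoc suggests, is through a version shim whose Spark 4+ variant simply mixes the trait in (the shim names below are hypothetical):

  trait MyExtensionConversions                   // cross-version shim used by the extension

  object Spark4PlusConversions
    extends MyExtensionConversions
    with org.apache.spark.sql.classic.ClassicConversions
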
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index aae424afcb8ac..1bf6f4e4d7d9f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -474,7 +474,7 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
// Bucketed scan only has one time overhead but can have multi-times benefits in cache,
// so we always do bucketed scan in a cached plan.
var disableConfigs = Seq(SQLConf.AUTO_BUCKETED_SCAN_ENABLED)
- if (!session.conf.get(SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING)) {
+ if (!session.sessionState.conf.getConf(SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING)) {
// Allowing changing cached plan output partitioning might lead to regression as it introduces
// extra shuffle
disableConfigs =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/EmptyRelationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/EmptyRelationExec.scala
index 8a544de7567e8..a0c3d7b51c2c3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/EmptyRelationExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/EmptyRelationExec.scala
@@ -71,13 +71,15 @@ case class EmptyRelationExec(@transient logical: LogicalPlan) extends LeafExecNo
maxFields,
printNodeId,
indent)
- lastChildren.add(true)
- logical.generateTreeString(
- depth + 1, lastChildren, append, verbose, "", false, maxFields, printNodeId, indent)
- lastChildren.remove(lastChildren.size() - 1)
+ Option(logical).foreach { _ =>
+ lastChildren.add(true)
+ logical.generateTreeString(
+ depth + 1, lastChildren, append, verbose, "", false, maxFields, printNodeId, indent)
+ lastChildren.remove(lastChildren.size() - 1)
+ }
}
override def doCanonicalize(): SparkPlan = {
- this.copy(logical = LocalRelation(logical.output).canonicalized)
+ this.copy(logical = LocalRelation(output).canonicalized)
}
}
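
The `Option(logical).foreach` guard above is the standard null-safe idiom, presumably protecting against the `@transient logical` field being null after deserialization. The same idiom in isolation:

  def appendIfPresent(child: AnyRef)(append: String => Unit): Unit =
    Option(child).foreach(c => append(c.toString))  // no-op when `child` is null
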
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/MultiResultExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/MultiResultExec.scala
new file mode 100644
index 0000000000000..c2b12b053c927
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/MultiResultExec.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+
+case class MultiResultExec(children: Seq[SparkPlan]) extends SparkPlan {
+
+ override def output: Seq[Attribute] = children.lastOption.map(_.output).getOrElse(Nil)
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ children.lastOption.map(_.execute()).getOrElse(sparkContext.emptyRDD)
+ }
+
+ override protected def withNewChildrenInternal(
+ newChildren: IndexedSeq[SparkPlan]): MultiResultExec = {
+ copy(children = newChildren)
+ }
+}
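
`MultiResultExec` uses a "last statement wins" convention: every child stays in the plan, but only the last one's schema and rows surface as the node's output. The same idea on plain sequences, for intuition:

  def lastResult[A](children: Seq[Seq[A]]): Seq[A] =
    children.lastOption.getOrElse(Nil)  // mirrors output/doExecute falling back to empty
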
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index 58fff2d4a1a29..5db14a8662138 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -44,7 +44,7 @@ object SQLExecution extends Logging {
private def nextExecutionId: Long = _nextExecutionId.getAndIncrement
- private val executionIdToQueryExecution = new ConcurrentHashMap[Long, QueryExecution]()
+ private[sql] val executionIdToQueryExecution = new ConcurrentHashMap[Long, QueryExecution]()
def getQueryExecution(executionId: Long): QueryExecution = {
executionIdToQueryExecution.get(executionId)
@@ -52,6 +52,9 @@ object SQLExecution extends Logging {
private val testing = sys.props.contains(IS_TESTING.key)
+ private[sql] def executionIdJobTag(session: SparkSession, id: Long) =
+ s"${session.sessionJobTag}-execution-root-id-$id"
+
private[sql] def checkSQLExecutionId(sparkSession: SparkSession): Unit = {
val sc = sparkSession.sparkContext
// only throw an exception during tests. a missing execution ID should not fail a job.
@@ -82,12 +85,13 @@ object SQLExecution extends Logging {
// And for the root execution, rootExecutionId == executionId.
if (sc.getLocalProperty(EXECUTION_ROOT_ID_KEY) == null) {
sc.setLocalProperty(EXECUTION_ROOT_ID_KEY, executionId.toString)
+ sc.addJobTag(executionIdJobTag(sparkSession, executionId))
}
val rootExecutionId = sc.getLocalProperty(EXECUTION_ROOT_ID_KEY).toLong
executionIdToQueryExecution.put(executionId, queryExecution)
val originalInterruptOnCancel = sc.getLocalProperty(SPARK_JOB_INTERRUPT_ON_CANCEL)
if (originalInterruptOnCancel == null) {
- val interruptOnCancel = sparkSession.conf.get(SQLConf.INTERRUPT_ON_CANCEL)
+ val interruptOnCancel = sparkSession.sessionState.conf.getConf(SQLConf.INTERRUPT_ON_CANCEL)
sc.setInterruptOnCancel(interruptOnCancel)
}
try {
@@ -116,92 +120,94 @@ object SQLExecution extends Logging {
val redactedConfigs = sparkSession.sessionState.conf.redactOptions(modifiedConfigs)
withSQLConfPropagated(sparkSession) {
- var ex: Option[Throwable] = None
- var isExecutedPlanAvailable = false
- val startTime = System.nanoTime()
- val startEvent = SparkListenerSQLExecutionStart(
- executionId = executionId,
- rootExecutionId = Some(rootExecutionId),
- description = desc,
- details = callSite.longForm,
- physicalPlanDescription = "",
- sparkPlanInfo = SparkPlanInfo.EMPTY,
- time = System.currentTimeMillis(),
- modifiedConfigs = redactedConfigs,
- jobTags = sc.getJobTags(),
- jobGroupId = Option(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID))
- )
- try {
- body match {
- case Left(e) =>
- sc.listenerBus.post(startEvent)
+ withSessionTagsApplied(sparkSession) {
+ var ex: Option[Throwable] = None
+ var isExecutedPlanAvailable = false
+ val startTime = System.nanoTime()
+ val startEvent = SparkListenerSQLExecutionStart(
+ executionId = executionId,
+ rootExecutionId = Some(rootExecutionId),
+ description = desc,
+ details = callSite.longForm,
+ physicalPlanDescription = "",
+ sparkPlanInfo = SparkPlanInfo.EMPTY,
+ time = System.currentTimeMillis(),
+ modifiedConfigs = redactedConfigs,
+ jobTags = sc.getJobTags(),
+ jobGroupId = Option(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID))
+ )
+ try {
+ body match {
+ case Left(e) =>
+ sc.listenerBus.post(startEvent)
+ throw e
+ case Right(f) =>
+ val planDescriptionMode =
+ ExplainMode.fromString(sparkSession.sessionState.conf.uiExplainMode)
+ val planDesc = queryExecution.explainString(planDescriptionMode)
+ val planInfo = try {
+ SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan)
+ } catch {
+ case NonFatal(e) =>
+ logDebug("Failed to generate SparkPlanInfo", e)
+ // If the queryExecution already failed before this, we are not able to generate
+ // the plan info, so we use an empty graphviz node to make the UI happy
+ SparkPlanInfo.EMPTY
+ }
+ sc.listenerBus.post(
+ startEvent.copy(physicalPlanDescription = planDesc, sparkPlanInfo = planInfo))
+ isExecutedPlanAvailable = true
+ f()
+ }
+ } catch {
+ case e: Throwable =>
+ ex = Some(e)
throw e
- case Right(f) =>
- val planDescriptionMode =
- ExplainMode.fromString(sparkSession.sessionState.conf.uiExplainMode)
- val planDesc = queryExecution.explainString(planDescriptionMode)
- val planInfo = try {
- SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan)
- } catch {
- case NonFatal(e) =>
- logDebug("Failed to generate SparkPlanInfo", e)
- // If the queryExecution already failed before this, we are not able to generate
- // the the plan info, so we use and empty graphviz node to make the UI happy
- SparkPlanInfo.EMPTY
- }
- sc.listenerBus.post(
- startEvent.copy(physicalPlanDescription = planDesc, sparkPlanInfo = planInfo))
- isExecutedPlanAvailable = true
- f()
- }
- } catch {
- case e: Throwable =>
- ex = Some(e)
- throw e
- } finally {
- val endTime = System.nanoTime()
- val errorMessage = ex.map {
- case e: SparkThrowable =>
- SparkThrowableHelper.getMessage(e, ErrorMessageFormat.PRETTY)
- case e =>
- Utils.exceptionString(e)
- }
- if (queryExecution.shuffleCleanupMode != DoNotCleanup
- && isExecutedPlanAvailable) {
- val shuffleIds = queryExecution.executedPlan match {
- case ae: AdaptiveSparkPlanExec =>
- ae.context.shuffleIds.asScala.keys
- case _ =>
- Iterable.empty
+ } finally {
+ val endTime = System.nanoTime()
+ val errorMessage = ex.map {
+ case e: SparkThrowable =>
+ SparkThrowableHelper.getMessage(e, ErrorMessageFormat.PRETTY)
+ case e =>
+ Utils.exceptionString(e)
}
- shuffleIds.foreach { shuffleId =>
- queryExecution.shuffleCleanupMode match {
- case RemoveShuffleFiles =>
- // Same as what we do in ContextCleaner.doCleanupShuffle, but do not unregister
- // the shuffle on MapOutputTracker, so that stage retries would be triggered.
- // Set blocking to Utils.isTesting to deflake unit tests.
- sc.shuffleDriverComponents.removeShuffle(shuffleId, Utils.isTesting)
- case SkipMigration =>
- SparkEnv.get.blockManager.migratableResolver.addShuffleToSkip(shuffleId)
- case _ => // this should not happen
+ if (queryExecution.shuffleCleanupMode != DoNotCleanup
+ && isExecutedPlanAvailable) {
+ val shuffleIds = queryExecution.executedPlan match {
+ case ae: AdaptiveSparkPlanExec =>
+ ae.context.shuffleIds.asScala.keys
+ case _ =>
+ Iterable.empty
+ }
+ shuffleIds.foreach { shuffleId =>
+ queryExecution.shuffleCleanupMode match {
+ case RemoveShuffleFiles =>
+ // Same as what we do in ContextCleaner.doCleanupShuffle, but do not unregister
+ // the shuffle on MapOutputTracker, so that stage retries would be triggered.
+ // Set blocking to Utils.isTesting to deflake unit tests.
+ sc.shuffleDriverComponents.removeShuffle(shuffleId, Utils.isTesting)
+ case SkipMigration =>
+ SparkEnv.get.blockManager.migratableResolver.addShuffleToSkip(shuffleId)
+ case _ => // this should not happen
+ }
}
}
+ val event = SparkListenerSQLExecutionEnd(
+ executionId,
+ System.currentTimeMillis(),
+ // Use empty string to indicate no error, as None may mean events generated by old
+ // versions of Spark.
+ errorMessage.orElse(Some("")))
+ // Currently only `Dataset.withAction` and `DataFrameWriter.runCommand` specify the
+ // `name` parameter. The `ExecutionListenerManager` only watches SQL executions with
+ // name. We can specify the execution name in more places in the future, so that
+ // `QueryExecutionListener` can track more cases.
+ event.executionName = name
+ event.duration = endTime - startTime
+ event.qe = queryExecution
+ event.executionFailure = ex
+ sc.listenerBus.post(event)
}
- val event = SparkListenerSQLExecutionEnd(
- executionId,
- System.currentTimeMillis(),
- // Use empty string to indicate no error, as None may mean events generated by old
- // versions of Spark.
- errorMessage.orElse(Some("")))
- // Currently only `Dataset.withAction` and `DataFrameWriter.runCommand` specify the `name`
- // parameter. The `ExecutionListenerManager` only watches SQL executions with name. We
- // can specify the execution name in more places in the future, so that
- // `QueryExecutionListener` can track more cases.
- event.executionName = name
- event.duration = endTime - startTime
- event.qe = queryExecution
- event.executionFailure = ex
- sc.listenerBus.post(event)
}
}
} finally {
@@ -211,6 +217,7 @@ object SQLExecution extends Logging {
// The current execution is the root execution if rootExecutionId == executionId.
if (sc.getLocalProperty(EXECUTION_ROOT_ID_KEY) == executionId.toString) {
sc.setLocalProperty(EXECUTION_ROOT_ID_KEY, null)
+ sc.removeJobTag(executionIdJobTag(sparkSession, executionId))
}
sc.setLocalProperty(SPARK_JOB_INTERRUPT_ON_CANCEL, originalInterruptOnCancel)
}
@@ -238,15 +245,28 @@ object SQLExecution extends Logging {
val sc = sparkSession.sparkContext
val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
withSQLConfPropagated(sparkSession) {
- try {
- sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId)
- body
- } finally {
- sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
+ withSessionTagsApplied(sparkSession) {
+ try {
+ sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId)
+ body
+ } finally {
+ sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
+ }
}
}
}
+ private[sql] def withSessionTagsApplied[T](sparkSession: SparkSession)(block: => T): T = {
+ val allTags = sparkSession.managedJobTags.values().asScala.toSet + sparkSession.sessionJobTag
+ sparkSession.sparkContext.addJobTags(allTags)
+
+ try {
+ block
+ } finally {
+ sparkSession.sparkContext.removeJobTags(allTags)
+ }
+ }
+
/**
* Wrap an action with specified SQL configs. These configs will be propagated to the executor
* side via job local properties.
@@ -286,10 +306,13 @@ object SQLExecution extends Logging {
val originalSession = SparkSession.getActiveSession
val originalLocalProps = sc.getLocalProperties
SparkSession.setActiveSession(activeSession)
- sc.setLocalProperties(localProps)
- val res = body
- // reset active session and local props.
- sc.setLocalProperties(originalLocalProps)
+ val res = withSessionTagsApplied(activeSession) {
+ sc.setLocalProperties(localProps)
+ val res = body
+ // reset active session and local props.
+ sc.setLocalProperties(originalLocalProps)
+ res
+ }
if (originalSession.nonEmpty) {
SparkSession.setActiveSession(originalSession.get)
} else {
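
The tag handling added in this file boils down to a scoping helper: apply the session's job tags for the duration of a block and always remove them afterwards, so that cancellation by tag only reaches work started inside the block. A simplified sketch (the real helper derives the tag set from the session's managed tags and `sessionJobTag`):

  def withTags[T](sc: org.apache.spark.SparkContext, tags: Set[String])(block: => T): T = {
    sc.addJobTags(tags)             // visible to every job launched inside `block`
    try block
    finally sc.removeJobTags(tags)  // restore, even if `block` throws
  }
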
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
index 8f27a0e8f673d..a8261e5d98ba0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -33,7 +33,6 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
import org.apache.spark.sql.catalyst.parser._
import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.trees.TreePattern.PARAMETER
import org.apache.spark.sql.catalyst.util.DateTimeConstants
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryParsingErrors}
import org.apache.spark.sql.execution.command._
@@ -466,22 +465,6 @@ class SparkSqlAstBuilder extends AstBuilder {
}
}
-
- private def checkInvalidParameter(plan: LogicalPlan, statement: String):
- Unit = {
- plan.foreach { p =>
- p.expressions.foreach { expr =>
- if (expr.containsPattern(PARAMETER)) {
- throw QueryParsingErrors.parameterMarkerNotAllowed(statement, p.origin)
- }
- }
- }
- plan.children.foreach(p => checkInvalidParameter(p, statement))
- plan.innerChildren.collect {
- case child: LogicalPlan => checkInvalidParameter(child, statement)
- }
- }
-
/**
* Create or replace a view. This creates a [[CreateViewCommand]].
*
@@ -537,12 +520,13 @@ class SparkSqlAstBuilder extends AstBuilder {
}
val qPlan: LogicalPlan = plan(ctx.query)
- // Disallow parameter markers in the body of the view.
+ // Disallow parameter markers in the query of the view.
// We need this limitation because we store the original query text, pre substitution.
- // To lift this we would need to reconstitute the body with parameter markers replaced with the
+ // To lift this we would need to reconstitute the query with parameter markers replaced with the
// values given at CREATE VIEW time, or we would need to store the parameter values alongside
// the text.
- checkInvalidParameter(qPlan, "CREATE VIEW body")
+ // The same rule can be found in the CACHE TABLE builder.
+ checkInvalidParameter(qPlan, "the query of CREATE VIEW")
if (viewType == PersistedView) {
val originalText = source(ctx.query)
assert(Option(originalText).isDefined,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 6d940a30619fb..aee735e48fc5c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -1041,6 +1041,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case WriteFiles(child, fileFormat, partitionColumns, bucket, options, staticPartitions) =>
WriteFilesExec(planLater(child), fileFormat, partitionColumns, bucket, options,
staticPartitions) :: Nil
+ case MultiResult(children) =>
+ MultiResultExec(children.map(planLater)) :: Nil
case _ => Nil
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index 3832d73044078..09d9915022a65 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -504,7 +504,7 @@ case class ScalaAggregator[IN, BUF, OUT](
private[this] lazy val inputDeserializer = inputEncoder.createDeserializer()
private[this] lazy val bufferSerializer = bufferEncoder.createSerializer()
private[this] lazy val bufferDeserializer = bufferEncoder.createDeserializer()
- private[this] lazy val outputEncoder = agg.outputEncoder.asInstanceOf[ExpressionEncoder[OUT]]
+ private[this] lazy val outputEncoder = encoderFor(agg.outputEncoder)
private[this] lazy val outputSerializer = outputEncoder.createSerializer()
def dataType: DataType = outputEncoder.objSerializer.dataType
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala
index 5a9adf8ab553d..91454c79df600 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala
@@ -330,7 +330,7 @@ object CommandUtils extends Logging {
val attributePercentiles = mutable.HashMap[Attribute, ArrayData]()
if (attrsToGenHistogram.nonEmpty) {
val percentiles = (0 to conf.histogramNumBins)
- .map(i => i.toDouble / conf.histogramNumBins).toArray
+ .map(i => i.toDouble / conf.histogramNumBins).toArray[Any]
val namedExprs = attrsToGenHistogram.map { attr =>
val aggFunc =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
index ea2736b2c1266..ea9d53190546e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan, SupervisingCommand}
+import org.apache.spark.sql.catalyst.plans.logical.{Command, ExecutableDuringAnalysis, LogicalPlan, SupervisingCommand}
import org.apache.spark.sql.catalyst.trees.{LeafLike, UnaryLike}
import org.apache.spark.sql.connector.ExternalCommandRunner
import org.apache.spark.sql.execution.{CommandExecutionMode, ExplainMode, LeafExecNode, SparkPlan, UnaryExecNode}
@@ -165,14 +165,19 @@ case class ExplainCommand(
// Run through the optimizer to generate the physical plan.
override def run(sparkSession: SparkSession): Seq[Row] = try {
- val outputString = sparkSession.sessionState.executePlan(logicalPlan, CommandExecutionMode.SKIP)
- .explainString(mode)
+ val stagedLogicalPlan = stageForAnalysis(logicalPlan)
+ val qe = sparkSession.sessionState.executePlan(stagedLogicalPlan, CommandExecutionMode.SKIP)
+ val outputString = qe.explainString(mode)
Seq(Row(outputString))
} catch { case NonFatal(cause) =>
("Error occurred during query planning: \n" + cause.getMessage).split("\n")
.map(Row(_)).toImmutableArraySeq
}
+ private def stageForAnalysis(plan: LogicalPlan): LogicalPlan = plan transform {
+ case p: ExecutableDuringAnalysis => p.stageForExplain()
+ }
+
def withTransformedSupervisedPlan(transformer: LogicalPlan => LogicalPlan): LogicalPlan =
copy(logicalPlan = transformer(logicalPlan))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index 3f221bfa53051..814e56b204f9e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -861,7 +861,7 @@ case class RepairTableCommand(
// Hive metastore may not have enough memory to handle millions of partitions in single RPC,
// we should split them into smaller batches. Since Hive client is not thread safe, we cannot
// do this in parallel.
- val batchSize = spark.conf.get(SQLConf.ADD_PARTITION_BATCH_SIZE)
+ val batchSize = spark.sessionState.conf.getConf(SQLConf.ADD_PARTITION_BATCH_SIZE)
partitionSpecsAndLocs.iterator.grouped(batchSize).foreach { batch =>
val now = MILLISECONDS.toSeconds(System.currentTimeMillis())
val parts = batch.map { case (spec, location) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
index e1061a46db7b0..071e3826b20a0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
@@ -135,7 +135,7 @@ case class CreateViewCommand(
referredTempFunctions)
catalog.createTempView(name.table, tableDefinition, overrideIfExists = replace)
} else if (viewType == GlobalTempView) {
- val db = sparkSession.conf.get(StaticSQLConf.GLOBAL_TEMP_DATABASE)
+ val db = sparkSession.sessionState.conf.getConf(StaticSQLConf.GLOBAL_TEMP_DATABASE)
val viewIdent = TableIdentifier(name.table, Option(db))
val aliasedPlan = aliasPlan(sparkSession, analyzedPlan)
val tableDefinition = createTemporaryViewRelation(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala
index 4fa1e0c1f2c58..fd47feef25d57 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala
@@ -19,7 +19,8 @@ package org.apache.spark.sql.execution.datasources
import org.apache.spark.sql.catalyst.expressions.{Attribute, SpecificInternalRow, UnsafeProjection}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
-import org.apache.spark.sql.types.{DataType, StringType}
+import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.util.SchemaUtils
object BucketingUtils {
// The file name of bucketed data should have 3 parts:
@@ -53,10 +54,7 @@ object BucketingUtils {
bucketIdGenerator(mutableInternalRow).getInt(0)
}
- def canBucketOn(dataType: DataType): Boolean = dataType match {
- case st: StringType => st.supportsBinaryOrdering
- case other => true
- }
+ def canBucketOn(dataType: DataType): Boolean = !SchemaUtils.hasNonUTF8BinaryCollation(dataType)
def bucketIdToString(id: Int): String = f"_$id%05d"
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index d88b5ee8877d7..968c204841e46 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -267,7 +267,8 @@ case class DataSource(
checkAndGlobPathIfNecessary(checkEmptyGlobPath = false, checkFilesExist = false)
createInMemoryFileIndex(globbedPaths)
})
- val forceNullable = sparkSession.conf.get(SQLConf.FILE_SOURCE_SCHEMA_FORCE_NULLABLE)
+ val forceNullable = sparkSession.sessionState.conf
+ .getConf(SQLConf.FILE_SOURCE_SCHEMA_FORCE_NULLABLE)
val sourceDataSchema = if (forceNullable) dataSchema.asNullable else dataSchema
SourceInfo(
s"FileSource[$path]",
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 1dd2659a1b169..2be4b236872f0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertIntoDir, I
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
import org.apache.spark.sql.catalyst.types.DataTypeUtils
-import org.apache.spark.sql.catalyst.util.{GeneratedColumn, ResolveDefaultColumns, V2ExpressionBuilder}
+import org.apache.spark.sql.catalyst.util.{GeneratedColumn, IdentityColumn, ResolveDefaultColumns, V2ExpressionBuilder}
import org.apache.spark.sql.connector.catalog.{SupportsRead, V1Table}
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, NullOrdering, SortDirection, SortOrder => V2SortOrder, SortValue}
@@ -146,6 +146,11 @@ object DataSourceAnalysis extends Rule[LogicalPlan] {
tableDesc.identifier, "generated columns")
}
+ if (IdentityColumn.hasIdentityColumns(newSchema)) {
+ throw QueryCompilationErrors.unsupportedTableOperationError(
+ tableDesc.identifier, "identity columns")
+ }
+
val newTableDesc = tableDesc.copy(schema = newSchema)
CreateDataSourceTableCommand(newTableDesc, ignoreIfExists = mode == SaveMode.Ignore)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
index 676a2ab64d0a3..ffdca65151052 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
@@ -29,7 +29,7 @@ import scala.util.control.NonFatal
import org.apache.hadoop.fs.Path
-import org.apache.spark.SparkRuntimeException
+import org.apache.spark.{SparkException, SparkRuntimeException}
import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
import org.apache.spark.sql.catalyst.analysis.TypeCoercion
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
@@ -550,7 +550,7 @@ object PartitioningUtils extends SQLConfHelper {
Cast(Literal(unescapePathName(value)), it).eval()
case BinaryType => value.getBytes()
case BooleanType => value.toBoolean
- case dt => throw QueryExecutionErrors.typeUnsupportedError(dt)
+ case dt => throw SparkException.internalError(s"Unsupported partition type: $dt")
}
def validatePartitionColumn(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
index 5423232db4293..e44f1d35e9cdf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
@@ -19,12 +19,14 @@ package org.apache.spark.sql.execution.datasources
import scala.util.control.NonFatal
+import org.apache.spark.SparkThrowable
import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{CTEInChildren, CTERelationDef, LogicalPlan, WithCTE}
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
+import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.command.LeafRunnableCommand
-import org.apache.spark.sql.sources.CreatableRelationProvider
+import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider}
/**
* Saves the results of `query` in to a data source.
@@ -44,8 +46,26 @@ case class SaveIntoDataSourceCommand(
override def innerChildren: Seq[QueryPlan[_]] = Seq(query)
override def run(sparkSession: SparkSession): Seq[Row] = {
- val relation = dataSource.createRelation(
- sparkSession.sqlContext, mode, options, Dataset.ofRows(sparkSession, query))
+ var relation: BaseRelation = null
+
+ try {
+ relation = dataSource.createRelation(
+ sparkSession.sqlContext, mode, options, Dataset.ofRows(sparkSession, query))
+ } catch {
+ case e: SparkThrowable =>
+ // We should avoid wrapping `SparkThrowable` exceptions into another `AnalysisException`.
+ throw e
+ case e @ (_: NullPointerException | _: MatchError | _: ArrayIndexOutOfBoundsException) =>
+ // These are some of the exceptions thrown by the data source API. We catch these
+ // exceptions here and rethrow QueryCompilationErrors.externalDataSourceException to
+ // provide a more friendly error message for the user. This list is not exhaustive.
+ throw QueryCompilationErrors.externalDataSourceException(e)
+ case e: Throwable =>
+ // For other exceptions, rethrow them as-is, since we don't have enough information
+ // to provide a better error message for the user at the moment. We may want to
+ // further improve the error message handling in the future.
+ throw e
+ }
try {
val logicalRelation = LogicalRelation(relation, toAttributes(relation.schema), None, false)
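
The try/catch added above keeps exceptions that already carry a Spark error class untouched and wraps only a handful of low-level runtime failures from V1 sources in a friendlier error. A standalone sketch of the same wrapping rule, using hypothetical SparkThrowableLike and UserFacingError stand-ins rather than the real Spark classes:

// Sketch only: SparkThrowableLike and UserFacingError are hypothetical stand-ins
// for SparkThrowable and QueryCompilationErrors.externalDataSourceException.
trait SparkThrowableLike extends Throwable

final case class UserFacingError(cause: Throwable)
  extends RuntimeException("External data source failed", cause)

object WrapDataSourceErrorsSketch {
  def createRelationSafely[T](create: => T): T = {
    try {
      create
    } catch {
      // Errors that already carry a user-facing error class pass through untouched.
      case e: SparkThrowableLike => throw e
      // Common low-level failures from the V1 data source API get a friendlier wrapper.
      case e @ (_: NullPointerException | _: MatchError | _: ArrayIndexOutOfBoundsException) =>
        throw UserFacingError(e)
      // Everything else propagates unchanged (no extra case needed).
    }
  }
}
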
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala
index cbff526592f92..54c100282e2db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala
@@ -98,7 +98,7 @@ class BinaryFileFormat extends FileFormat with DataSourceRegister {
val broadcastedHadoopConf =
sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
val filterFuncs = filters.flatMap(filter => createFilterFunction(filter))
- val maxLength = sparkSession.conf.get(SOURCES_BINARY_FILE_MAX_LENGTH)
+ val maxLength = sparkSession.sessionState.conf.getConf(SOURCES_BINARY_FILE_MAX_LENGTH)
file: PartitionedFile => {
val path = file.toPath
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
index fc6cba786c4ed..d9367d92d462e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
@@ -115,7 +115,7 @@ case class CreateTempViewUsing(
}.logicalPlan
if (global) {
- val db = sparkSession.conf.get(StaticSQLConf.GLOBAL_TEMP_DATABASE)
+ val db = sparkSession.sessionState.conf.getConf(StaticSQLConf.GLOBAL_TEMP_DATABASE)
val viewIdent = TableIdentifier(tableIdent.table, Option(db))
val viewDefinition = createTemporaryViewRelation(
viewIdent,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 112ee2c5450b2..76cd33b815edd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruning
import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.util.{toPrettySQL, GeneratedColumn, ResolveDefaultColumns, V2ExpressionBuilder}
+import org.apache.spark.sql.catalyst.util.{toPrettySQL, GeneratedColumn, IdentityColumn, ResolveDefaultColumns, V2ExpressionBuilder}
import org.apache.spark.sql.connector.catalog.{Identifier, StagingTableCatalog, SupportsDeleteV2, SupportsNamespaces, SupportsPartitionManagement, SupportsWrite, Table, TableCapability, TableCatalog, TruncatableTable}
import org.apache.spark.sql.connector.catalog.index.SupportsIndex
import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue}
@@ -185,6 +185,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
val statementType = "CREATE TABLE"
GeneratedColumn.validateGeneratedColumns(
c.tableSchema, catalog.asTableCatalog, ident, statementType)
+ IdentityColumn.validateIdentityColumn(c.tableSchema, catalog.asTableCatalog, ident)
CreateTableExec(
catalog.asTableCatalog,
@@ -214,6 +215,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
val statementType = "REPLACE TABLE"
GeneratedColumn.validateGeneratedColumns(
c.tableSchema, catalog.asTableCatalog, ident, statementType)
+ IdentityColumn.validateIdentityColumn(c.tableSchema, catalog.asTableCatalog, ident)
val v2Columns = columns.map(_.toV2Column(statementType)).toArray
catalog match {
@@ -551,6 +553,9 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
systemScope,
pattern) :: Nil
+ case c: Call =>
+ ExplainOnlySparkPlan(c) :: Nil
+
case _ => Nil
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExplainOnlySparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExplainOnlySparkPlan.scala
new file mode 100644
index 0000000000000..bbf56eaa71184
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExplainOnlySparkPlan.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.trees.LeafLike
+import org.apache.spark.sql.execution.SparkPlan
+
+case class ExplainOnlySparkPlan(toExplain: LogicalPlan) extends SparkPlan with LeafLike[SparkPlan] {
+
+ override def output: Seq[Attribute] = Nil
+
+ override def simpleString(maxFields: Int): String = {
+ toExplain.simpleString(maxFields)
+ }
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ throw new UnsupportedOperationException()
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala
index f56f9436d9437..4eee731e0b2d6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala
@@ -146,6 +146,19 @@ abstract class FileTable(
val entry = options.get(DataSource.GLOB_PATHS_KEY)
Option(entry).map(_ == "true").getOrElse(true)
}
+
+ /**
+ * Merge the options of this FileTable with the options of the table operation,
+ * giving precedence to the keys of the table operation.
+ *
+ * @param options The options of the table operation.
+ * @return The merged options, where the table operation's values override the
+ * table-level values for duplicate keys.
+ */
+ protected def mergedOptions(options: CaseInsensitiveStringMap): CaseInsensitiveStringMap = {
+ val finalOptions = this.options.asCaseSensitiveMap().asScala ++
+ options.asCaseSensitiveMap().asScala
+ new CaseInsensitiveStringMap(finalOptions.asJava)
+ }
}
object FileTable {
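
The merge rule in mergedOptions above folds the per-operation options over the table-level options with `++`, so the operation's keys win on conflict. A small standalone sketch of that precedence, assuming only CaseInsensitiveStringMap from the connector API:

import scala.jdk.CollectionConverters._

import org.apache.spark.sql.util.CaseInsensitiveStringMap

object MergedOptionsSketch {
  // Same merge rule as FileTable.mergedOptions: operation options override table options.
  def merge(
      tableOptions: CaseInsensitiveStringMap,
      operationOptions: CaseInsensitiveStringMap): CaseInsensitiveStringMap = {
    val merged = tableOptions.asCaseSensitiveMap().asScala ++
      operationOptions.asCaseSensitiveMap().asScala
    new CaseInsensitiveStringMap(merged.asJava)
  }

  def main(args: Array[String]): Unit = {
    val table = new CaseInsensitiveStringMap(Map("header" -> "true", "sep" -> ",").asJava)
    val op = new CaseInsensitiveStringMap(Map("sep" -> "|").asJava)
    // Prints "|": the per-operation value takes precedence over the table default.
    println(merge(table, op).get("sep"))
  }
}

The CSV, JSON, ORC, Parquet, and Text scan builders below all switch to this helper, so table-level options are no longer lost at scan time while per-read options still take precedence.
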
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala
index 8b4fd3af6ded7..4c201ca66cf6c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala
@@ -38,7 +38,7 @@ case class CSVTable(
fallbackFileFormat: Class[_ <: FileFormat])
extends FileTable(sparkSession, options, paths, userSpecifiedSchema) {
override def newScanBuilder(options: CaseInsensitiveStringMap): CSVScanBuilder =
- CSVScanBuilder(sparkSession, fileIndex, schema, dataSchema, options)
+ CSVScanBuilder(sparkSession, fileIndex, schema, dataSchema, mergedOptions(options))
override def inferSchema(files: Seq[FileStatus]): Option[StructType] = {
val parsedOptions = new CSVOptions(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonTable.scala
index c567e87e7d767..54244c4d95e77 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonTable.scala
@@ -38,7 +38,7 @@ case class JsonTable(
fallbackFileFormat: Class[_ <: FileFormat])
extends FileTable(sparkSession, options, paths, userSpecifiedSchema) {
override def newScanBuilder(options: CaseInsensitiveStringMap): JsonScanBuilder =
- new JsonScanBuilder(sparkSession, fileIndex, schema, dataSchema, options)
+ JsonScanBuilder(sparkSession, fileIndex, schema, dataSchema, mergedOptions(options))
override def inferSchema(files: Seq[FileStatus]): Option[StructType] = {
val parsedOptions = new JSONOptionsInRead(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala
index ca4b83b3c58f1..1037370967c87 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala
@@ -38,7 +38,7 @@ case class OrcTable(
extends FileTable(sparkSession, options, paths, userSpecifiedSchema) {
override def newScanBuilder(options: CaseInsensitiveStringMap): OrcScanBuilder =
- new OrcScanBuilder(sparkSession, fileIndex, schema, dataSchema, options)
+ OrcScanBuilder(sparkSession, fileIndex, schema, dataSchema, mergedOptions(options))
override def inferSchema(files: Seq[FileStatus]): Option[StructType] =
OrcUtils.inferSchema(sparkSession, files, options.asScala.toMap)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetTable.scala
index e593ad7d0c0cd..8463a05569c05 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetTable.scala
@@ -38,7 +38,7 @@ case class ParquetTable(
extends FileTable(sparkSession, options, paths, userSpecifiedSchema) {
override def newScanBuilder(options: CaseInsensitiveStringMap): ParquetScanBuilder =
- new ParquetScanBuilder(sparkSession, fileIndex, schema, dataSchema, options)
+ ParquetScanBuilder(sparkSession, fileIndex, schema, dataSchema, mergedOptions(options))
override def inferSchema(files: Seq[FileStatus]): Option[StructType] =
ParquetUtils.inferSchema(sparkSession, options.asScala.toMap, files)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
index 83399e2cac01b..50b90641d309b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
@@ -25,7 +25,7 @@ import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{RuntimeConfig, SparkSession}
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.DataSourceOptions
import org.apache.spark.sql.connector.catalog.{Table, TableProvider}
import org.apache.spark.sql.connector.expressions.Transform
@@ -119,9 +119,9 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
throw StateDataSourceErrors.offsetMetadataLogUnavailable(batchId, checkpointLocation)
)
- val clonedRuntimeConf = new RuntimeConfig(session.sessionState.conf.clone())
- OffsetSeqMetadata.setSessionConf(metadata, clonedRuntimeConf)
- StateStoreConf(clonedRuntimeConf.sqlConf)
+ val clonedSqlConf = session.sessionState.conf.clone()
+ OffsetSeqMetadata.setSessionConf(metadata, clonedSqlConf)
+ StateStoreConf(clonedSqlConf)
case _ =>
throw StateDataSourceErrors.offsetLogUnavailable(batchId, checkpointLocation)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
index 53576c335cb01..24166a46bbd39 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
@@ -19,12 +19,13 @@ package org.apache.spark.sql.execution.datasources.v2.state
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
+import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
import org.apache.spark.sql.execution.streaming.{StateVariableType, TransformWithStateVariableInfo}
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.execution.streaming.state.RecordType.{getRecordTypeAsString, RecordType}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{NullType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.{NextIterator, SerializableConfiguration}
@@ -68,10 +69,23 @@ abstract class StatePartitionReaderBase(
stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema])
extends PartitionReader[InternalRow] with Logging {
- protected val keySchema = SchemaUtil.getSchemaAsDataType(
- schema, "key").asInstanceOf[StructType]
- protected val valueSchema = SchemaUtil.getSchemaAsDataType(
- schema, "value").asInstanceOf[StructType]
+ // Used primarily as a placeholder for the value schema in the context of
+ // state variables used within the transformWithState operator.
+ private val schemaForValueRow: StructType =
+ StructType(Array(StructField("__dummy__", NullType)))
+
+ protected val keySchema = {
+ if (!SchemaUtil.isMapStateVariable(stateVariableInfoOpt)) {
+ SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType]
+ } else SchemaUtil.getCompositeKeySchema(schema)
+ }
+
+ protected val valueSchema = if (stateVariableInfoOpt.isDefined) {
+ schemaForValueRow
+ } else {
+ SchemaUtil.getSchemaAsDataType(
+ schema, "value").asInstanceOf[StructType]
+ }
protected lazy val provider: StateStoreProvider = {
val stateStoreId = StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString,
@@ -84,10 +98,17 @@ abstract class StatePartitionReaderBase(
false
}
+ val useMultipleValuesPerKey = stateVariableInfoOpt.exists(
+ _.stateVariableType == StateVariableType.ListState)
+
val provider = StateStoreProvider.createAndInit(
stateStoreProviderId, keySchema, valueSchema, keyStateEncoderSpec,
useColumnFamilies = useColFamilies, storeConf, hadoopConf.value,
- useMultipleValuesPerKey = false)
+ useMultipleValuesPerKey = useMultipleValuesPerKey)
if (useColFamilies) {
val store = provider.getStore(partition.sourceOptions.batchId + 1)
@@ -99,7 +120,7 @@ abstract class StatePartitionReaderBase(
stateStoreColFamilySchema.keySchema,
stateStoreColFamilySchema.valueSchema,
stateStoreColFamilySchema.keyStateEncoderSpec.get,
- useMultipleValuesPerKey = false)
+ useMultipleValuesPerKey = useMultipleValuesPerKey)
}
provider
}
@@ -160,32 +181,43 @@ class StatePartitionReader(
override lazy val iter: Iterator[InternalRow] = {
val stateVarName = stateVariableInfoOpt
.map(_.stateName).getOrElse(StateStore.DEFAULT_COL_FAMILY_NAME)
- store
- .iterator(stateVarName)
- .map { pair =>
- stateVariableInfoOpt match {
- case Some(stateVarInfo) =>
- val stateVarType = stateVarInfo.stateVariableType
- val hasTTLEnabled = stateVarInfo.ttlEnabled
-
- stateVarType match {
- case StateVariableType.ValueState =>
- if (hasTTLEnabled) {
- SchemaUtil.unifyStateRowPairWithTTL((pair.key, pair.value), valueSchema,
- partition.partition)
- } else {
+ if (SchemaUtil.isMapStateVariable(stateVariableInfoOpt)) {
+ SchemaUtil.unifyMapStateRowPair(
+ store.iterator(stateVarName), keySchema, partition.partition)
+ } else {
+ store
+ .iterator(stateVarName)
+ .map { pair =>
+ stateVariableInfoOpt match {
+ case Some(stateVarInfo) =>
+ val stateVarType = stateVarInfo.stateVariableType
+
+ stateVarType match {
+ case StateVariableType.ValueState =>
SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition)
- }
- case _ =>
- throw new IllegalStateException(
- s"Unsupported state variable type: $stateVarType")
- }
+ case StateVariableType.ListState =>
+ val key = pair.key
+ val result = store.valuesIterator(key, stateVarName)
+ val unsafeRowArr: Seq[UnsafeRow] = result.map(_.copy()).toSeq
+ // convert the list of values to array type
+ val arrData = new GenericArrayData(unsafeRowArr.toArray)
+ SchemaUtil.unifyStateRowPairWithMultipleValues((pair.key, arrData),
+ partition.partition)
+
+ case _ =>
+ throw new IllegalStateException(
+ s"Unsupported state variable type: $stateVarType")
+ }
- case None =>
- SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition)
+ case None =>
+ SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition)
+ }
}
- }
+ }
}
override def close(): Unit = {
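
For list state, the reader above drains store.valuesIterator and copies every UnsafeRow before collecting them into a GenericArrayData. A plain-Scala model of that collect-and-copy step (no Spark types; ReusedRow is a hypothetical stand-in for a row object that an iterator may reuse between steps), illustrating why the copy() matters:

import scala.collection.mutable.ArrayBuffer

// ReusedRow is a hypothetical stand-in for a mutable row buffer that an iterator
// may reuse, which is the usual reason row-by-row copies are required.
final class ReusedRow(var value: Int) {
  def copy(): ReusedRow = new ReusedRow(value)
  override def toString: String = s"ReusedRow($value)"
}

object ListStateCollectSketch {
  def collectValues(values: Iterator[ReusedRow]): Array[ReusedRow] = {
    val buf = ArrayBuffer.empty[ReusedRow]
    values.foreach { entry =>
      // Copy before buffering: without it every buffered element would alias the
      // same reused object and end up holding only the last value.
      buf += entry.copy()
    }
    buf.toArray
  }
}

Collecting into an ArrayBuffer (or iterator.map(_.copy()).toSeq) also avoids the quadratic cost of repeatedly appending to an immutable Seq with :+.
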
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala
index 9dd357530ec40..88ea06d598e56 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala
@@ -16,13 +16,18 @@
*/
package org.apache.spark.sql.execution.datasources.v2.state.utils
+import scala.collection.mutable
+import scala.util.control.NonFatal
+
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.execution.datasources.v2.state.{StateDataSourceErrors, StateSourceOptions}
-import org.apache.spark.sql.execution.streaming.{StateVariableType, TransformWithStateVariableInfo}
-import org.apache.spark.sql.execution.streaming.state.StateStoreColFamilySchema
-import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType, StructType}
+import org.apache.spark.sql.execution.streaming.StateVariableType._
+import org.apache.spark.sql.execution.streaming.TransformWithStateVariableInfo
+import org.apache.spark.sql.execution.streaming.state.{StateStoreColFamilySchema, UnsafeRowPair}
+import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType, LongType, MapType, StringType, StructType}
import org.apache.spark.util.ArrayImplicits._
object SchemaUtil {
@@ -70,18 +75,122 @@ object SchemaUtil {
row
}
- def unifyStateRowPairWithTTL(
- pair: (UnsafeRow, UnsafeRow),
- valueSchema: StructType,
+ def unifyStateRowPairWithMultipleValues(
+ pair: (UnsafeRow, GenericArrayData),
partition: Int): InternalRow = {
- val row = new GenericInternalRow(4)
+ val row = new GenericInternalRow(3)
row.update(0, pair._1)
- row.update(1, pair._2.get(0, valueSchema))
- row.update(2, pair._2.get(1, LongType))
- row.update(3, partition)
+ row.update(1, pair._2)
+ row.update(2, partition)
row
}
+ /**
+ * For map state variables, state rows are stored under a composite key
+ * (grouping key, user key). To return one state reader row per grouping key in the
+ * form grouping key -> Map(user key -> value), we group the state rows by their
+ * grouping key and build a map for each group.
+ *
+ * We traverse the iterator returned from the state store and return a row from
+ * `next()` only when the grouping key of the next state row differs from the
+ * current one (or there are no more rows).
+ *
+ * Note that all state rows with the same grouping key are co-located, so they
+ * appear consecutively during the iterator traversal.
+ */
+ def unifyMapStateRowPair(
+ stateRows: Iterator[UnsafeRowPair],
+ compositeKeySchema: StructType,
+ partitionId: Int): Iterator[InternalRow] = {
+ val groupingKeySchema = SchemaUtil.getSchemaAsDataType(
+ compositeKeySchema, "key"
+ ).asInstanceOf[StructType]
+ val userKeySchema = SchemaUtil.getSchemaAsDataType(
+ compositeKeySchema, "userKey"
+ ).asInstanceOf[StructType]
+
+ def appendKVPairToMap(
+ curMap: mutable.Map[Any, Any],
+ stateRowPair: UnsafeRowPair): Unit = {
+ curMap += (
+ stateRowPair.key.get(1, userKeySchema)
+ .asInstanceOf[UnsafeRow].copy() ->
+ stateRowPair.value.copy()
+ )
+ }
+
+ def createDataRow(
+ groupingKey: Any,
+ curMap: mutable.Map[Any, Any]): GenericInternalRow = {
+ val row = new GenericInternalRow(3)
+ val mapData = ArrayBasedMapData(curMap)
+ row.update(0, groupingKey)
+ row.update(1, mapData)
+ row.update(2, partitionId)
+ row
+ }
+
+ // All rows with the same grouping key are co-located and appear consecutively,
+ // so a single pass over the iterator is enough to build one map per grouping key.
+ new Iterator[InternalRow] {
+ var curGroupingKey: UnsafeRow = _
+ var curStateRowPair: UnsafeRowPair = _
+ val curMap = mutable.Map.empty[Any, Any]
+
+ override def hasNext: Boolean =
+ stateRows.hasNext || !curMap.isEmpty
+
+ override def next(): InternalRow = {
+ var foundNewGroupingKey = false
+ while (stateRows.hasNext && !foundNewGroupingKey) {
+ curStateRowPair = stateRows.next()
+ if (curGroupingKey == null) {
+ // First row from the iterator.
+ // Make a copy because the grouping key has to survive across next() calls,
+ // while the underlying row object may be reused by the store iterator.
+ curGroupingKey = curStateRowPair.key
+ .get(0, groupingKeySchema).asInstanceOf[UnsafeRow].copy()
+ appendKVPairToMap(curMap, curStateRowPair)
+ } else {
+ val curPairGroupingKey =
+ curStateRowPair.key.get(0, groupingKeySchema)
+ if (curPairGroupingKey == curGroupingKey) {
+ appendKVPairToMap(curMap, curStateRowPair)
+ } else {
+ // found a different grouping key; exit the loop and return a row
+ foundNewGroupingKey = true
+ }
+ }
+ }
+ if (foundNewGroupingKey) {
+ // found a different grouping key
+ val row = createDataRow(curGroupingKey, curMap)
+ // update the current grouping key to start the next group
+ curGroupingKey =
+ curStateRowPair.key.get(0, groupingKeySchema)
+ .asInstanceOf[UnsafeRow].copy()
+ // empty the map, append current row
+ curMap.clear()
+ appendKVPairToMap(curMap, curStateRowPair)
+ // return map value of previous grouping key
+ row
+ } else {
+ if (curMap.isEmpty) {
+ throw new NoSuchElementException("next() was called on an exhausted iterator; " +
+ "check hasNext() before calling next().")
+ } else {
+ // reached the end of the state rows; emit the final group
+ val row = createDataRow(curGroupingKey, curMap)
+ // clear the map to end the iterator
+ curMap.clear()
+ row
+ }
+ }
+ }
+ }
+ }
+
def isValidSchema(
sourceOptions: StateSourceOptions,
schema: StructType,
@@ -91,23 +200,26 @@ object SchemaUtil {
"change_type" -> classOf[StringType],
"key" -> classOf[StructType],
"value" -> classOf[StructType],
- "partition_id" -> classOf[IntegerType],
- "expiration_timestamp" -> classOf[LongType])
+ "single_value" -> classOf[StructType],
+ "list_value" -> classOf[ArrayType],
+ "map_value" -> classOf[MapType],
+ "partition_id" -> classOf[IntegerType])
val expectedFieldNames = if (sourceOptions.readChangeFeed) {
Seq("batch_id", "change_type", "key", "value", "partition_id")
} else if (transformWithStateVariableInfoOpt.isDefined) {
val stateVarInfo = transformWithStateVariableInfoOpt.get
- val hasTTLEnabled = stateVarInfo.ttlEnabled
val stateVarType = stateVarInfo.stateVariableType
stateVarType match {
- case StateVariableType.ValueState =>
- if (hasTTLEnabled) {
- Seq("key", "value", "expiration_timestamp", "partition_id")
- } else {
- Seq("key", "value", "partition_id")
- }
+ case ValueState =>
+ Seq("key", "single_value", "partition_id")
+
+ case ListState =>
+ Seq("key", "list_value", "partition_id")
+
+ case MapState =>
+ Seq("key", "map_value", "partition_id")
case _ =>
throw StateDataSourceErrors
@@ -131,27 +243,73 @@ object SchemaUtil {
stateVarInfo: TransformWithStateVariableInfo,
stateStoreColFamilySchema: StateStoreColFamilySchema): StructType = {
val stateVarType = stateVarInfo.stateVariableType
- val hasTTLEnabled = stateVarInfo.ttlEnabled
stateVarType match {
- case StateVariableType.ValueState =>
- if (hasTTLEnabled) {
- val ttlValueSchema = SchemaUtil.getSchemaAsDataType(
- stateStoreColFamilySchema.valueSchema, "value").asInstanceOf[StructType]
- new StructType()
- .add("key", stateStoreColFamilySchema.keySchema)
- .add("value", ttlValueSchema)
- .add("expiration_timestamp", LongType)
- .add("partition_id", IntegerType)
- } else {
- new StructType()
- .add("key", stateStoreColFamilySchema.keySchema)
- .add("value", stateStoreColFamilySchema.valueSchema)
- .add("partition_id", IntegerType)
- }
+ case ValueState =>
+ new StructType()
+ .add("key", stateStoreColFamilySchema.keySchema)
+ .add("single_value", stateStoreColFamilySchema.valueSchema)
+ .add("partition_id", IntegerType)
+
+ case ListState =>
+ new StructType()
+ .add("key", stateStoreColFamilySchema.keySchema)
+ .add("list_value", ArrayType(stateStoreColFamilySchema.valueSchema))
+ .add("partition_id", IntegerType)
+
+ case MapState =>
+ val groupingKeySchema = SchemaUtil.getSchemaAsDataType(
+ stateStoreColFamilySchema.keySchema, "key")
+ val userKeySchema = stateStoreColFamilySchema.userKeyEncoderSchema.get
+ val valueMapSchema = MapType.apply(
+ keyType = userKeySchema,
+ valueType = stateStoreColFamilySchema.valueSchema
+ )
+
+ new StructType()
+ .add("key", groupingKeySchema)
+ .add("map_value", valueMapSchema)
+ .add("partition_id", IntegerType)
case _ =>
throw StateDataSourceErrors.internalError(s"Unsupported state variable type $stateVarType")
}
}
+
+ /**
+ * Helper functions for the map state data source reader.
+ *
+ * Map state variables are stored in the RocksDB state store with the schema of
+ * `TransformWithStateKeyValueRowSchemaUtils.getCompositeKeySchema()`,
+ * but the state data source reader needs to return rows in the format of
+ * "key": groupingKey, "map_value": Map(userKey -> value).
+ *
+ * The following functions help to translate between the two schemas.
+ */
+ def isMapStateVariable(
+ stateVariableInfoOpt: Option[TransformWithStateVariableInfo]): Boolean = {
+ stateVariableInfoOpt.isDefined &&
+ stateVariableInfoOpt.get.stateVariableType == MapState
+ }
+
+ /**
+ * Given the key-value schema generated from `generateSchemaForStateVar()`,
+ * returns the composite key schema under which the key is stored in the state store.
+ */
+ def getCompositeKeySchema(schema: StructType): StructType = {
+ val groupingKeySchema = SchemaUtil.getSchemaAsDataType(
+ schema, "key").asInstanceOf[StructType]
+ val userKeySchema = try {
+ Option(
+ SchemaUtil.getSchemaAsDataType(schema, "map_value").asInstanceOf[MapType]
+ .keyType.asInstanceOf[StructType])
+ } catch {
+ case NonFatal(e) =>
+ throw StateDataSourceErrors.internalError(s"No field named 'map_value' found " +
+ s"during state source reader schema initialization. Internal exception message: $e")
+ }
+ new StructType()
+ .add("key", groupingKeySchema)
+ .add("userKey", userKeySchema.get)
+ }
}
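
The iterator returned by unifyMapStateRowPair is essentially a run-length grouping over rows that arrive ordered by grouping key. A standalone sketch of the same grouping over plain (groupingKey, userKey, value) tuples, assuming the same co-location guarantee; it uses a BufferedIterator instead of the explicit curGroupingKey/curMap bookkeeping above:

// Plain-Scala model of the grouping performed by unifyMapStateRowPair:
// consecutive rows with the same grouping key are folded into one Map.
object MapStateGroupingSketch {
  def groupByConsecutiveKey[K, UK, V](
      rows: Iterator[(K, UK, V)]): Iterator[(K, Map[UK, V])] =
    new Iterator[(K, Map[UK, V])] {
      private val buffered = rows.buffered

      override def hasNext: Boolean = buffered.hasNext

      override def next(): (K, Map[UK, V]) = {
        if (!buffered.hasNext) throw new NoSuchElementException("next() on empty iterator")
        val groupingKey = buffered.head._1
        val entries = scala.collection.mutable.Map.empty[UK, V]
        // Rows for one grouping key appear consecutively, so we can stop at the first
        // row whose grouping key differs.
        while (buffered.hasNext && buffered.head._1 == groupingKey) {
          val (_, userKey, value) = buffered.next()
          entries += userKey -> value
        }
        groupingKey -> entries.toMap
      }
    }

  def main(args: Array[String]): Unit = {
    val rows = Iterator(("a", 1, "x"), ("a", 2, "y"), ("b", 1, "z"))
    // Prints: (a,Map(1 -> x, 2 -> y)) then (b,Map(1 -> z))
    groupByConsecutiveKey(rows).foreach(println)
  }
}
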
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextTable.scala
index 046bdcb69846e..87ae34532f88a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextTable.scala
@@ -34,7 +34,7 @@ case class TextTable(
fallbackFileFormat: Class[_ <: FileFormat])
extends FileTable(sparkSession, options, paths, userSpecifiedSchema) {
override def newScanBuilder(options: CaseInsensitiveStringMap): TextScanBuilder =
- TextScanBuilder(sparkSession, fileIndex, schema, dataSchema, options)
+ TextScanBuilder(sparkSession, fileIndex, schema, dataSchema, mergedOptions(options))
override def inferSchema(files: Seq[FileStatus]): Option[StructType] =
Some(StructType(Array(StructField("value", StringType))))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInArrowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInArrowExec.scala
index e91140414732b..a2d200dc86e18 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInArrowExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInArrowExec.scala
@@ -23,13 +23,13 @@ import org.apache.spark.sql.execution.SparkPlan
/**
- * Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapCoGroupsInPandas]]
+ * Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapCoGroupsInArrow]]
*
* The input dataframes are first Cogrouped. Rows from each side of the cogroup are passed to the
* Python worker via Arrow. As each side of the cogroup may have a different schema we send every
* group in its own Arrow stream.
- * The Python worker turns the resulting record batches to `pandas.DataFrame`s, invokes the
- * user-defined function, and passes the resulting `pandas.DataFrame`
+ * The Python worker turns the resulting record batches to `pyarrow.Table`s, invokes the
+ * user-defined function, and passes the resulting `pyarrow.Table`
* as an Arrow record batch. Finally, each record batch is turned to
* Iterator[InternalRow] using ColumnarBatch.
*
@@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.SparkPlan
* Both the Python worker and the Java executor need to have enough memory to
* hold the largest cogroup. The memory on the Java side is used to construct the
* record batches (off heap memory). The memory on the Python side is used for
- * holding the `pandas.DataFrame`. It's possible to further split one group into
+ * holding the `pyarrow.Table`. It's possible to further split one group into
* multiple record batches to reduce the memory footprint on the Java side, this
* is left as future work.
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInArrowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInArrowExec.scala
index 942aaf6e44c17..6569b29f3954f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInArrowExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInArrowExec.scala
@@ -25,11 +25,11 @@ import org.apache.spark.sql.types.{StructField, StructType}
/**
- * Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandas]]
+ * Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInArrow]]
*
* Rows in each group are passed to the Python worker as an Arrow record batch.
- * The Python worker turns the record batch to a `pandas.DataFrame`, invoke the
- * user-defined function, and passes the resulting `pandas.DataFrame`
+ * The Python worker turns the record batch to a `pyarrow.Table`, invokes the
+ * user-defined function, and passes the resulting `pyarrow.Table`
* as an Arrow record batch. Finally, each record batch is turned to
* Iterator[InternalRow] using ColumnarBatch.
*
@@ -37,7 +37,7 @@ import org.apache.spark.sql.types.{StructField, StructType}
* Both the Python worker and the Java executor need to have enough memory to
* hold the largest group. The memory on the Java side is used to construct the
* record batch (off heap memory). The memory on the Python side is used for
- * holding the `pandas.DataFrame`. It's possible to further split one group into
+ * holding the `pyarrow.Table`. It's possible to further split one group into
* multiple record batches to reduce the memory footprint on the Java side, this
* is left as future work.
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala
index 67b264436fea9..ed7ff6a753487 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala
@@ -21,12 +21,15 @@ import java.io.File
import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.ReentrantLock
-import org.apache.spark.{JobArtifactSet, SparkEnv, TaskContext}
+import scala.util.control.NonFatal
+
+import org.apache.spark.{JobArtifactSet, SparkEnv, SparkThrowable, TaskContext}
import org.apache.spark.api.python._
import org.apache.spark.internal.Logging
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.sql.ForeachWriter
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.streaming.sources.ForeachUserFuncException
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.{NextIterator, Utils}
@@ -53,6 +56,8 @@ class WriterThread(outputIterator: Iterator[Array[Byte]])
} catch {
// Cache exceptions seen while evaluating the Python function on the streamed input. The
// parent thread will throw this crashed exception eventually.
+ case NonFatal(e) if !e.isInstanceOf[SparkThrowable] =>
+ _exception = ForeachUserFuncException(e)
case t: Throwable =>
_exception = t
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonStreamingSourceRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonStreamingSourceRunner.scala
index 33612b6947f27..11ab706e8abb4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonStreamingSourceRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonStreamingSourceRunner.scala
@@ -233,7 +233,17 @@ class PythonStreamingSourceRunner(
s"stream reader for $pythonExec", 0, Long.MaxValue)
def readArrowRecordBatches(): Iterator[InternalRow] = {
- assert(dataIn.readInt() == SpecialLengths.START_ARROW_STREAM)
+ val status = dataIn.readInt()
+ status match {
+ case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
+ val msg = PythonWorkerUtils.readUTF(dataIn)
+ throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(
+ action = "prefetchArrowBatches", msg)
+ case SpecialLengths.START_ARROW_STREAM =>
+ case _ =>
+ throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(
+ action = "prefetchArrowBatches", s"unknown status code $status")
+ }
val reader = new ArrowStreamReader(dataIn, allocator)
val root = reader.getVectorSchemaRoot()
// When input is empty schema can't be read.
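
The prefetch path above now branches on the control integer the Python worker writes before the Arrow stream instead of asserting on it. A standalone sketch of that framing check over a DataInputStream; the numeric codes are illustrative placeholders, not the real SpecialLengths values:

import java.io.DataInputStream

object StatusCodeSketch {
  // Illustrative control codes; the real ones live in SpecialLengths.
  val StartArrowStream: Int = -6
  val PythonExceptionThrown: Int = -7

  def expectArrowStream(dataIn: DataInputStream): Unit = {
    val status = dataIn.readInt()
    status match {
      case PythonExceptionThrown =>
        // The worker sent an error frame (length-prefixed UTF message) instead of data.
        val msg = dataIn.readUTF()
        throw new RuntimeException(s"Python worker failed: $msg")
      case StartArrowStream =>
        // Expected case: the Arrow stream follows; the caller can start reading it.
      case other =>
        throw new RuntimeException(s"Unknown status code $other from Python worker")
    }
  }
}
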
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala
index 35275eb16ebab..f6a3f1b1394f9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala
@@ -133,7 +133,7 @@ case class TransformWithStateInPandasExec(
val data = groupAndProject(dataIterator, groupingAttributes, child.output, dedupAttributes)
val processorHandle = new StatefulProcessorHandleImpl(store, getStateInfo.queryRunId,
- groupingKeyExprEncoder, timeMode)
+ groupingKeyExprEncoder, timeMode, isStreaming = true, batchTimestampMs, metrics)
val runner = new TransformWithStateInPandasPythonRunner(
chainedFunc,
PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
index d23f1dbd422a7..b5ec26b401d28 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.python
import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, EOFException}
import java.net.ServerSocket
+import java.time.Duration
import scala.collection.mutable
@@ -30,7 +31,7 @@ import org.apache.spark.sql.api.python.PythonSQLUtils
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl, StatefulProcessorHandleState}
import org.apache.spark.sql.execution.streaming.state.StateMessage.{HandleState, ImplicitGroupingKeyRequest, StatefulProcessorCall, StateRequest, StateResponse, StateVariableRequest, ValueStateCall}
-import org.apache.spark.sql.streaming.ValueState
+import org.apache.spark.sql.streaming.{TTLConfig, ValueState}
import org.apache.spark.sql.types.StructType
/**
@@ -153,7 +154,10 @@ class TransformWithStateInPandasStateServer(
case StatefulProcessorCall.MethodCase.GETVALUESTATE =>
val stateName = message.getGetValueState.getStateName
val schema = message.getGetValueState.getSchema
- initializeValueState(stateName, schema)
+ val ttlDurationMs = if (message.getGetValueState.hasTtl) {
+ Some(message.getGetValueState.getTtl.getDurationMs)
+ } else None
+ initializeValueState(stateName, schema, ttlDurationMs)
case _ =>
throw new IllegalArgumentException("Invalid method call")
}
@@ -228,10 +232,18 @@ class TransformWithStateInPandasStateServer(
outputStream.write(responseMessageBytes)
}
- private def initializeValueState(stateName: String, schemaString: String): Unit = {
+ private def initializeValueState(
+ stateName: String,
+ schemaString: String,
+ ttlDurationMs: Option[Int]): Unit = {
if (!valueStates.contains(stateName)) {
val schema = StructType.fromString(schemaString)
- val state = statefulProcessorHandle.getValueState[Row](stateName, Encoders.row(schema))
+ val state = if (ttlDurationMs.isEmpty) {
+ statefulProcessorHandle.getValueState[Row](stateName, Encoders.row(schema))
+ } else {
+ statefulProcessorHandle.getValueState(
+ stateName, Encoders.row(schema), TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
+ }
val valueRowDeserializer = ExpressionEncoder(schema).resolveAndBind().createDeserializer()
valueStates.put(stateName, (state, schema, valueRowDeserializer))
sendResponse(0)
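
The initializeValueState change above only differs in whether a TTL config is passed along, based on an optional millisecond duration from the proto message. A tiny standalone sketch of that branch, with TtlConfigSketch as a stand-in for Spark's TTLConfig:

import java.time.Duration

// TtlConfigSketch stands in for org.apache.spark.sql.streaming.TTLConfig.
final case class TtlConfigSketch(ttlDuration: Duration)

object ValueStateTtlSketch {
  // None means the state variable is created without TTL; Some(ms) enables expiry.
  def resolveTtl(ttlDurationMs: Option[Int]): Option[TtlConfigSketch] =
    ttlDurationMs.map(ms => TtlConfigSketch(Duration.ofMillis(ms.toLong)))

  def main(args: Array[String]): Unit = {
    println(resolveTtl(None))        // None -> no TTL
    println(resolveTtl(Some(60000))) // Some(TtlConfigSketch(PT1M)) -> 1 minute TTL
  }
}
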
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala
index aa393211a1c15..cb7e71bda84dd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala
@@ -21,6 +21,7 @@ import java.util.concurrent.atomic.AtomicBoolean
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.ThreadUtils
@@ -40,9 +41,16 @@ trait AsyncLogPurge extends Logging {
private val purgeRunning = new AtomicBoolean(false)
+ private val statefulMetadataPurgeRunning = new AtomicBoolean(false)
+
protected def purge(threshold: Long): Unit
- protected lazy val useAsyncPurge: Boolean = sparkSession.conf.get(SQLConf.ASYNC_LOG_PURGE)
+ // Purges the oldest OperatorStateMetadata and StateSchema files, which are
+ // written per run.
+ protected def purgeStatefulMetadata(plan: SparkPlan): Unit
+
+ protected lazy val useAsyncPurge: Boolean = sparkSession.sessionState.conf
+ .getConf(SQLConf.ASYNC_LOG_PURGE)
protected def purgeAsync(batchId: Long): Unit = {
if (purgeRunning.compareAndSet(false, true)) {
@@ -62,6 +70,24 @@ trait AsyncLogPurge extends Logging {
}
}
+ protected def purgeStatefulMetadataAsync(plan: SparkPlan): Unit = {
+ if (statefulMetadataPurgeRunning.compareAndSet(false, true)) {
+ asyncPurgeExecutorService.execute(() => {
+ try {
+ purgeStatefulMetadata(plan)
+ } catch {
+ case throwable: Throwable =>
+ logError("Encountered error while performing async stateful metadata purge", throwable)
+ errorNotifier.markError(throwable)
+ } finally {
+ statefulMetadataPurgeRunning.set(false)
+ }
+ })
+ } else {
+ log.debug("Skipped stateful metadata purging since there is already one in progress.")
+ }
+ }
+
protected def asyncLogPurgeShutdown(): Unit = {
ThreadUtils.shutdown(asyncPurgeExecutorService)
}
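
Both purgeAsync and the new purgeStatefulMetadataAsync follow the same guard pattern: an AtomicBoolean flipped with compareAndSet so at most one purge of each kind is in flight, the work handed to an executor, and the flag cleared in finally. A standalone sketch of that pattern with a plain single-thread executor:

import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicBoolean

// Standalone model of the compareAndSet-guarded async purge pattern used above.
final class AsyncPurgeSketch(purge: () => Unit) {
  private val running = new AtomicBoolean(false)
  private val executor = Executors.newSingleThreadExecutor()

  def purgeAsync(): Unit = {
    // Start a purge only if none is in flight; otherwise skip this round.
    if (running.compareAndSet(false, true)) {
      executor.execute { () =>
        try {
          purge()
        } catch {
          case t: Throwable => t.printStackTrace() // the real code records the error
        } finally {
          running.set(false) // reopen the gate for the next batch
        }
      }
    }
  }

  def shutdown(): Unit = executor.shutdown()
}
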
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index d56dfebd61ba1..766caaab2285e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -18,8 +18,11 @@ package org.apache.spark.sql.execution.streaming
import java.util.concurrent.TimeUnit.NANOSECONDS
+import scala.util.control.NonFatal
+
import org.apache.hadoop.conf.Configuration
+import org.apache.spark.{SparkException, SparkThrowable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -447,10 +450,33 @@ case class FlatMapGroupsWithStateExec(
hasTimedOut,
watermarkPresent)
- // Call function, get the returned objects and convert them to rows
- val mappedIterator = func(keyObj, valueObjIter, groupState).map { obj =>
- numOutputRows += 1
- getOutputRow(obj)
+ def withUserFuncExceptionHandling[T](func: => T): T = {
+ try {
+ func
+ } catch {
+ case NonFatal(e) if !e.isInstanceOf[SparkThrowable] =>
+ throw FlatMapGroupsWithStateUserFuncException(e)
+ case f: Throwable =>
+ throw f
+ }
+ }
+
+ val mappedIterator = withUserFuncExceptionHandling {
+ func(keyObj, valueObjIter, groupState).map { obj =>
+ numOutputRows += 1
+ getOutputRow(obj)
+ }
+ }
+
+ // Wrap the iterator so errors from the user-provided function get the same handling
+ val wrappedMappedIterator = new Iterator[InternalRow] {
+ override def hasNext: Boolean = {
+ withUserFuncExceptionHandling(mappedIterator.hasNext)
+ }
+
+ override def next(): InternalRow = {
+ withUserFuncExceptionHandling(mappedIterator.next())
+ }
}
// When the iterator is consumed, then write changes to state
@@ -472,7 +498,9 @@ case class FlatMapGroupsWithStateExec(
}
// Return an iterator of rows such that fully consumed, the updated state value will be saved
- CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIterator, onIteratorCompletion)
+ CompletionIterator[InternalRow, Iterator[InternalRow]](
+ wrappedMappedIterator, onIteratorCompletion
+ )
}
}
}
@@ -544,3 +572,13 @@ object FlatMapGroupsWithStateExec {
}
}
}
+
+
+/**
+ * Exception that wraps the exception thrown in the user-provided function of
+ * flatMapGroupsWithState.
+ */
+private[sql] case class FlatMapGroupsWithStateUserFuncException(cause: Throwable)
+ extends SparkException(
+ errorClass = "FLATMAPGROUPSWITHSTATE_USER_FUNCTION_ERROR",
+ messageParameters = Map("reason" -> Option(cause.getMessage).getOrElse("")),
+ cause = cause)
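
Like the foreach wrappers referenced elsewhere in this change, the new error class re-throws non-fatal, non-Spark exceptions from user code under a dedicated error class while keeping the original as the cause. A standalone sketch of that wrapping rule, with EngineError and UserFunctionError as hypothetical stand-ins for SparkThrowable and the wrapper class:

import scala.util.control.NonFatal

// Hypothetical stand-ins for SparkThrowable and the user-function wrapper class.
trait EngineError extends Throwable
final case class UserFunctionError(cause: Throwable)
  extends RuntimeException(Option(cause.getMessage).getOrElse(""), cause)

object UserFuncWrappingSketch {
  def withUserFuncExceptionHandling[T](body: => T): T = {
    try {
      body
    } catch {
      // Non-fatal errors that are not already engine errors are attributed to user code.
      case NonFatal(e) if !e.isInstanceOf[EngineError] => throw UserFunctionError(e)
      // Fatal errors and engine errors propagate as-is.
      case t: Throwable => throw t
    }
  }
}
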
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
index f59cdca8aefec..053aef6ced3a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
@@ -477,7 +477,7 @@ class MicroBatchExecution(
// update offset metadata
nextOffsets.metadata.foreach { metadata =>
- OffsetSeqMetadata.setSessionConf(metadata, sparkSessionToRunBatches.conf)
+ OffsetSeqMetadata.setSessionConf(metadata, sparkSessionToRunBatches.sessionState.conf)
execCtx.offsetSeqMetadata = OffsetSeqMetadata(
metadata.batchWatermarkMs, metadata.batchTimestampMs, sparkSessionToRunBatches.conf)
watermarkTracker = WatermarkTracker(sparkSessionToRunBatches.conf)
@@ -842,6 +842,10 @@ class MicroBatchExecution(
markMicroBatchExecutionStart(execCtx)
+ if (execCtx.previousContext.isEmpty) {
+ purgeStatefulMetadataAsync(execCtx.executionPlan.executedPlan)
+ }
+
val nextBatch =
new Dataset(execCtx.executionPlan, ExpressionEncoder(execCtx.executionPlan.analyzed.schema))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala
index f0be33ad9a9d8..e1e5b3a7ef88e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala
@@ -26,6 +26,7 @@ import org.apache.spark.io.CompressionCodec
import org.apache.spark.sql.RuntimeConfig
import org.apache.spark.sql.connector.read.streaming.{Offset => OffsetV2, SparkDataStream}
import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, StreamingAggregationStateManager, SymmetricHashJoinStateManager}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf._
@@ -100,7 +101,9 @@ object OffsetSeqMetadata extends Logging {
SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY,
FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, STREAMING_AGGREGATION_STATE_FORMAT_VERSION,
STREAMING_JOIN_STATE_FORMAT_VERSION, STATE_STORE_COMPRESSION_CODEC,
- STATE_STORE_ROCKSDB_FORMAT_VERSION, STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION)
+ STATE_STORE_ROCKSDB_FORMAT_VERSION, STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION,
+ PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN
+ )
/**
* Default values of relevant configurations that are used for backward compatibility.
@@ -121,7 +124,8 @@ object OffsetSeqMetadata extends Logging {
STREAMING_JOIN_STATE_FORMAT_VERSION.key ->
SymmetricHashJoinStateManager.legacyVersion.toString,
STATE_STORE_COMPRESSION_CODEC.key -> CompressionCodec.LZ4,
- STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION.key -> "false"
+ STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION.key -> "false",
+ PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN.key -> "true"
)
def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json)
@@ -135,20 +139,21 @@ object OffsetSeqMetadata extends Logging {
}
/** Set the SparkSession configuration with the values in the metadata */
- def setSessionConf(metadata: OffsetSeqMetadata, sessionConf: RuntimeConfig): Unit = {
+ def setSessionConf(metadata: OffsetSeqMetadata, sessionConf: SQLConf): Unit = {
+ val configs = sessionConf.getAllConfs
OffsetSeqMetadata.relevantSQLConfs.map(_.key).foreach { confKey =>
metadata.conf.get(confKey) match {
case Some(valueInMetadata) =>
// Config value exists in the metadata, update the session config with this value
- val optionalValueInSession = sessionConf.getOption(confKey)
- if (optionalValueInSession.isDefined && optionalValueInSession.get != valueInMetadata) {
+ val optionalValueInSession = sessionConf.getConfString(confKey, null)
+ if (optionalValueInSession != null && optionalValueInSession != valueInMetadata) {
logWarning(log"Updating the value of conf '${MDC(CONFIG, confKey)}' in current " +
- log"session from '${MDC(OLD_VALUE, optionalValueInSession.get)}' " +
+ log"session from '${MDC(OLD_VALUE, optionalValueInSession)}' " +
log"to '${MDC(NEW_VALUE, valueInMetadata)}'.")
}
- sessionConf.set(confKey, valueInMetadata)
+ sessionConf.setConfString(confKey, valueInMetadata)
case None =>
// For backward compatibility, if a config was not recorded in the offset log,
@@ -157,14 +162,17 @@ object OffsetSeqMetadata extends Logging {
relevantSQLConfDefaultValues.get(confKey) match {
case Some(defaultValue) =>
- sessionConf.set(confKey, defaultValue)
+ sessionConf.setConfString(confKey, defaultValue)
logWarning(log"Conf '${MDC(CONFIG, confKey)}' was not found in the offset log, " +
log"using default value '${MDC(DEFAULT_VALUE, defaultValue)}'")
case None =>
- val valueStr = sessionConf.getOption(confKey).map { v =>
- s" Using existing session conf value '$v'."
- }.getOrElse { " No value set in session conf." }
+ val value = sessionConf.getConfString(confKey, null)
+ val valueStr = if (value != null) {
+ s" Using existing session conf value '$value'."
+ } else {
+ " No value set in session conf."
+ }
logWarning(log"Conf '${MDC(CONFIG, confKey)}' was not found in the offset log. " +
log"${MDC(TIP, valueStr)}")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 6fd58e13366e0..8f030884ad33b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -41,8 +41,10 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table}
import org.apache.spark.sql.connector.read.streaming.{Offset => OffsetV2, ReadLimit, SparkDataStream}
import org.apache.spark.sql.connector.write.{LogicalWriteInfoImpl, SupportsTruncate, Write}
+import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command.StreamingExplainCommand
-import org.apache.spark.sql.execution.streaming.sources.ForeachBatchUserFuncException
+import org.apache.spark.sql.execution.streaming.sources.{ForeachBatchUserFuncException, ForeachUserFuncException}
+import org.apache.spark.sql.execution.streaming.state.OperatorStateMetadataV2FileManager
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.connector.SupportsStreamingUpdateAsAppend
import org.apache.spark.sql.streaming._
@@ -346,9 +348,11 @@ abstract class StreamExecution(
getLatestExecutionContext().updateStatusMessage("Stopped")
case e: Throwable =>
val message = if (e.getMessage == null) "" else e.getMessage
- val cause = if (e.isInstanceOf[ForeachBatchUserFuncException]) {
+ val cause = if (e.isInstanceOf[ForeachBatchUserFuncException] ||
+ e.isInstanceOf[ForeachUserFuncException]) {
// We want to maintain the current way users get the causing exception
- // from the StreamingQueryException. Hence the ForeachBatch exception is unwrapped here.
+ // from the StreamingQueryException.
+ // Hence the ForeachBatch/Foreach exception is unwrapped here.
e.getCause
} else {
e
@@ -367,13 +371,7 @@ abstract class StreamExecution(
messageParameters = Map(
"id" -> id.toString,
"runId" -> runId.toString,
- "message" -> message,
- "queryDebugString" -> toDebugString(includeLogicalPlan = isInitialized),
- "startOffset" -> getLatestExecutionContext().startOffsets.toOffsetSeq(
- sources.toSeq, getLatestExecutionContext().offsetSeqMetadata).toString,
- "endOffset" -> getLatestExecutionContext().endOffsets.toOffsetSeq(
- sources.toSeq, getLatestExecutionContext().offsetSeqMetadata).toString
- ))
+ "message" -> message))
errorClassOpt = e match {
case t: SparkThrowable => Option(t.getErrorClass)
@@ -485,7 +483,7 @@ abstract class StreamExecution(
@throws[TimeoutException]
protected def interruptAndAwaitExecutionThreadTermination(): Unit = {
val timeout = math.max(
- sparkSession.conf.get(SQLConf.STREAMING_STOP_TIMEOUT), 0)
+ sparkSession.sessionState.conf.getConf(SQLConf.STREAMING_STOP_TIMEOUT), 0)
queryExecutionThread.interrupt()
queryExecutionThread.join(timeout)
if (queryExecutionThread.isAlive) {
@@ -690,6 +688,31 @@ abstract class StreamExecution(
offsetLog.purge(threshold)
commitLog.purge(threshold)
}
+
+ protected def purgeStatefulMetadata(plan: SparkPlan): Unit = {
+ plan.collect { case statefulOperator: StatefulOperator =>
+ statefulOperator match {
+ case ssw: StateStoreWriter =>
+ ssw.operatorStateMetadataVersion match {
+ case 2 =>
+ // The checkpointLocation of the operator is runId/state and the commit log path
+ // is runId/commits, so we take the parent of the checkpointLocation to get the
+ // commit log path.
+ val parentCheckpointLocation =
+ new Path(statefulOperator.getStateInfo.checkpointLocation).getParent
+
+ val fileManager = new OperatorStateMetadataV2FileManager(
+ parentCheckpointLocation,
+ sparkSession,
+ ssw
+ )
+ fileManager.purgeMetadataFiles()
+ case _ =>
+ }
+ case _ =>
+ }
+ }
+ }
}
object StreamExecution {
@@ -728,6 +751,7 @@ object StreamExecution {
if e2.getCause != null =>
isInterruptionException(e2.getCause, sc)
case fe: ForeachBatchUserFuncException => isInterruptionException(fe.getCause, sc)
+ case fes: ForeachUserFuncException => isInterruptionException(fes.getCause, sc)
case se: SparkException =>
if (se.getCause == null) {
isCancelledJobGroup(se.getMessage)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala
index 6065af10ffe80..8811c59a50745 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala
@@ -85,7 +85,7 @@ abstract class SingleKeyTTLStateImpl(
import org.apache.spark.sql.execution.streaming.StateTTLSchema._
- private val ttlColumnFamilyName = s"_ttl_$stateName"
+ private val ttlColumnFamilyName = "$ttl_" + stateName
private val keySchema = getSingleKeyTTLRowSchema(keyExprEnc.schema)
private val keyTTLRowEncoder = new SingleKeyTTLEncoder(keyExprEnc)
@@ -205,7 +205,7 @@ abstract class CompositeKeyTTLStateImpl[K](
import org.apache.spark.sql.execution.streaming.StateTTLSchema._
- private val ttlColumnFamilyName = s"_ttl_$stateName"
+ private val ttlColumnFamilyName = "$ttl_" + stateName
private val keySchema = getCompositeKeyTTLRowSchema(
keyExprEnc.schema, userKeyEncoder.schema
)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala
index 650f57039a030..82a4226fcfd54 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala
@@ -30,8 +30,8 @@ import org.apache.spark.util.NextIterator
* Singleton utils class used primarily while interacting with TimerState
*/
object TimerStateUtils {
- val PROC_TIMERS_STATE_NAME = "_procTimers"
- val EVENT_TIMERS_STATE_NAME = "_eventTimers"
+ val PROC_TIMERS_STATE_NAME = "$procTimers"
+ val EVENT_TIMERS_STATE_NAME = "$eventTimers"
val KEY_TO_TIMESTAMP_CF = "_keyToTimestamp"
val TIMESTAMP_TO_KEY_CF = "_timestampToKey"
}
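
The renames above move internal TTL and timer state names onto a `$` prefix, and the RocksDBStateStoreProvider hunk further down rejects user-created column families that start with `$`. A minimal standalone sketch of that naming rule, with illustrative names and error text (not Spark's actual error class):

    object ColumnFamilyNaming {
      // Internal column families are prefixed with '$' so user-supplied names cannot collide.
      private val ReservedPrefix = '$'

      def validate(colFamilyName: String, isInternal: Boolean): Unit = {
        require(colFamilyName.nonEmpty, "column family name must not be empty")
        if (!isInternal && colFamilyName.charAt(0) == ReservedPrefix) {
          throw new IllegalArgumentException(
            s"Column family '$colFamilyName' starts with the reserved character '$ReservedPrefix'")
        }
      }
    }

    // ColumnFamilyNaming.validate("$procTimers", isInternal = true)  // accepted, internal name
    // ColumnFamilyNaming.validate("$mine", isInternal = false)       // rejected
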
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala
index 54c47ec4e6ed8..3e6f122f463d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/WatermarkTracker.scala
@@ -135,7 +135,8 @@ object WatermarkTracker {
// saved in the checkpoint (e.g., old checkpoints), then the default `min` policy is enforced
// through defaults specified in OffsetSeqMetadata.setSessionConf().
val policyName = conf.get(
- SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY, MultipleWatermarkPolicy.DEFAULT_POLICY_NAME)
+ SQLConf.STREAMING_MULTIPLE_WATERMARK_POLICY.key,
+ MultipleWatermarkPolicy.DEFAULT_POLICY_NAME)
new WatermarkTracker(MultipleWatermarkPolicy(policyName))
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala
index 420c3e3be16d6..273ffa6aefb7b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala
@@ -32,8 +32,9 @@ import org.apache.spark.SparkEnv
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.{HOST, PORT}
import org.apache.spark.rpc.RpcEndpointRef
+import org.apache.spark.sql.Encoders
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.connector.read.InputPartition
import org.apache.spark.sql.connector.read.streaming.{ContinuousPartitionReader, ContinuousPartitionReaderFactory, ContinuousStream, Offset, PartitionOffset}
@@ -57,8 +58,7 @@ class TextSocketContinuousStream(
implicit val defaultFormats: DefaultFormats = DefaultFormats
- private val encoder = ExpressionEncoder.tuple(ExpressionEncoder[String](),
- ExpressionEncoder[Timestamp]())
+ private val encoder = encoderFor(Encoders.tuple(Encoders.STRING, Encoders.TIMESTAMP))
@GuardedBy("this")
private var socket: Socket = _
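
The encoder change above swaps the internal ExpressionEncoder.tuple construction for the public Encoders factory. A hedged, standalone sketch of the equivalent tuple encoder (the trailing comment assumes an active SparkSession named `spark`):

    import java.sql.Timestamp
    import org.apache.spark.sql.{Encoder, Encoders}

    // Tuple encoder for (line, timestamp), matching what the socket source now builds.
    val lineWithTimestamp: Encoder[(String, Timestamp)] =
      Encoders.tuple(Encoders.STRING, Encoders.TIMESTAMP)

    // It can be passed wherever an explicit encoder is accepted, e.g.
    // spark.createDataset(Seq(("hello", new Timestamp(0L))))(lineWithTimestamp)
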
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala
index bbbe28ec7ab11..c0956a62e59fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala
@@ -19,6 +19,9 @@ package org.apache.spark.sql.execution.streaming.sources
import java.util
+import scala.util.control.NonFatal
+
+import org.apache.spark.{SparkException, SparkThrowable}
import org.apache.spark.sql.{ForeachWriter, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -146,6 +149,9 @@ class ForeachDataWriter[T](
try {
writer.process(rowConverter(record))
} catch {
+ case NonFatal(e) if !e.isInstanceOf[SparkThrowable] =>
+ errorOrNull = e
+ throw ForeachUserFuncException(e)
case t: Throwable =>
errorOrNull = t
throw t
@@ -172,3 +178,12 @@ class ForeachDataWriter[T](
* An empty [[WriterCommitMessage]]. [[ForeachWriter]] implementations have no global coordination.
*/
case object ForeachWriterCommitMessage extends WriterCommitMessage
+
+/**
+ * Exception that wraps the exception thrown by the user-provided function in the Foreach sink.
+ */
+private[sql] case class ForeachUserFuncException(cause: Throwable)
+ extends SparkException(
+ errorClass = "FOREACH_USER_FUNCTION_ERROR",
+ messageParameters = Map("reason" -> Option(cause.getMessage).getOrElse("")),
+ cause = cause)
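
The writer above wraps non-fatal, non-Spark errors from the user function so that StreamExecution (see the earlier hunk) can unwrap them before reporting. A simplified sketch of that wrap-then-unwrap pattern with a stand-in exception type; the real classes and the FOREACH_USER_FUNCTION_ERROR error class are Spark's, everything below is illustrative:

    import scala.util.control.NonFatal

    final case class UserFuncException(cause: Throwable) extends RuntimeException(cause)

    // Wrap only non-fatal errors so callers can tell user bugs apart from engine failures.
    def runUserFunction[T](record: T)(process: T => Unit): Unit =
      try {
        process(record)
      } catch {
        case NonFatal(e) => throw UserFuncException(e)
      }

    // Unwrap before surfacing, mirroring what the stream execution loop does.
    def rootCauseForReporting(e: Throwable): Throwable = e match {
      case UserFuncException(cause) => cause
      case other => other
    }
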
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala
index 3e68b3975e662..aa2f332afeff4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala
@@ -23,14 +23,14 @@ import java.nio.charset.StandardCharsets
import scala.reflect.ClassTag
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FSDataInputStream, FSDataOutputStream, Path}
+import org.apache.hadoop.fs.{FSDataInputStream, FSDataOutputStream, Path, PathFilter}
import org.json4s.{Formats, NoTypeHints}
import org.json4s.jackson.Serialization
import org.apache.spark.internal.{Logging, LogKeys, MDC}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.v2.state.StateDataSourceErrors
-import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, CommitLog, MetadataVersionUtil, OffsetSeqLog}
+import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, CommitLog, MetadataVersionUtil, OffsetSeqLog, StateStoreWriter}
import org.apache.spark.sql.execution.streaming.CheckpointFileManager.CancellableFSDataOutputStream
import org.apache.spark.sql.execution.streaming.StreamingCheckpointConstants.{DIR_NAME_COMMITS, DIR_NAME_OFFSETS}
import org.apache.spark.sql.execution.streaming.state.OperatorStateMetadataUtils.{OperatorStateMetadataReader, OperatorStateMetadataWriter}
@@ -358,3 +358,121 @@ class OperatorStateMetadataV2Reader(
}
}
}
+
+/**
+ * A helper class that manages the metadata files written for OperatorStateMetadataV2
+ * and provides utilities to purge the oldest files, so that we only keep the metadata
+ * files for which a commit log entry is still present.
+ * @param checkpointLocation The root path of the checkpoint directory
+ * @param sparkSession The SparkSession used to access the hadoopConf
+ * @param stateStoreWriter The operator that this file manager is created for
+ */
+class OperatorStateMetadataV2FileManager(
+ checkpointLocation: Path,
+ sparkSession: SparkSession,
+ stateStoreWriter: StateStoreWriter) extends Logging {
+
+ private val hadoopConf = sparkSession.sessionState.newHadoopConf()
+ private val stateCheckpointPath = new Path(checkpointLocation, "state")
+ private val stateOpIdPath = new Path(
+ stateCheckpointPath, stateStoreWriter.getStateInfo.operatorId.toString)
+ private val commitLog =
+ new CommitLog(sparkSession, new Path(checkpointLocation, "commits").toString)
+ private val stateSchemaPath = stateStoreWriter.stateSchemaDirPath()
+ private val metadataDirPath = OperatorStateMetadataV2.metadataDirPath(stateOpIdPath)
+ private lazy val fm = CheckpointFileManager.create(metadataDirPath, hadoopConf)
+
+ protected def isBatchFile(path: Path): Boolean = {
+ try {
+ path.getName.toLong
+ true
+ } catch {
+ case _: NumberFormatException => false
+ }
+ }
+
+ /**
+ * A `PathFilter` to filter only batch files
+ */
+ protected val batchFilesFilter: PathFilter = (path: Path) => isBatchFile(path)
+
+ private def pathToBatchId(path: Path): Long = {
+ path.getName.toLong
+ }
+
+ def purgeMetadataFiles(): Unit = {
+ val thresholdBatchId = findThresholdBatchId()
+ if (thresholdBatchId != 0) {
+ val earliestBatchIdKept = deleteMetadataFiles(thresholdBatchId)
+ // we need to delete everything from 0 to (earliestBatchIdKept - 1), inclusive
+ deleteSchemaFiles(earliestBatchIdKept - 1)
+ }
+ }
+
+  // We only want to keep the metadata and schema files for which the commit
+  // log is present, so we delete every file that precedes the batch of the
+  // oldest commit log entry.
+ private def findThresholdBatchId(): Long = {
+ commitLog.listBatchesOnDisk.headOption.getOrElse(0L)
+ }
+
+ private def deleteSchemaFiles(thresholdBatchId: Long): Unit = {
+ val schemaFiles = fm.list(stateSchemaPath).sorted.map(_.getPath)
+ val filesBeforeThreshold = schemaFiles.filter { path =>
+ val batchIdInPath = path.getName.split("_").head.toLong
+ batchIdInPath <= thresholdBatchId
+ }
+ filesBeforeThreshold.foreach { path =>
+ fm.delete(path)
+ }
+ }
+
+  // Deletes all metadata files below the thresholdBatchId, except for the latest
+  // metadata file, so that at least one metadata and schema file is kept at all
+  // times for each stateful query.
+  // Returns the batchId of the earliest schema file we want to keep.
+ private def deleteMetadataFiles(thresholdBatchId: Long): Long = {
+ val metadataFiles = fm.list(metadataDirPath, batchFilesFilter)
+
+ if (metadataFiles.isEmpty) {
+ return -1L // No files to delete
+ }
+
+    // collect the metadata files at or below the oldest batch that still has a commit log
+ val sortedBatchIds = metadataFiles
+ .map(file => pathToBatchId(file.getPath))
+ .filter(_ <= thresholdBatchId)
+ .sorted
+
+ if (sortedBatchIds.isEmpty) {
+ return -1L
+ }
+
+    // keep the newest eligible metadata file and only delete the ones older than it
+ val latestBatchId = sortedBatchIds.last
+
+ metadataFiles.foreach { batchFile =>
+ val batchId = pathToBatchId(batchFile.getPath)
+ if (batchId < latestBatchId) {
+ fm.delete(batchFile.getPath)
+ }
+ }
+ val latestMetadata = OperatorStateMetadataReader.createReader(
+ stateOpIdPath,
+ hadoopConf,
+ 2,
+ latestBatchId
+ ).read()
+
+ // find the batchId of the earliest schema file we need to keep
+ val earliestBatchToKeep = latestMetadata match {
+ case Some(OperatorStateMetadataV2(_, stateStoreInfo, _)) =>
+ val schemaFilePath = stateStoreInfo.head.stateSchemaFilePath
+ new Path(schemaFilePath).getName.split("_").head.toLong
+ case _ => 0
+ }
+
+ earliestBatchToKeep
+ }
+}
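
The purge logic above boils down to a pruning rule over batch ids: keep the newest metadata file at or below the oldest committed batch, plus everything newer. A toy sketch of that rule with plain longs (file listing and the real manager classes omitted):

    // Returns the metadata batch ids that can be deleted, given the oldest batch id
    // that still has a commit log entry.
    def metadataBatchesToDelete(metadataBatchIds: Seq[Long], oldestCommittedBatch: Long): Seq[Long] = {
      val eligible = metadataBatchIds.filter(_ <= oldestCommittedBatch).sorted
      if (eligible.isEmpty) Seq.empty
      else metadataBatchIds.filter(_ < eligible.last).sorted
    }

    // metadataBatchesToDelete(Seq(0L, 1L, 2L, 3L), oldestCommittedBatch = 2L) == Seq(0L, 1L)
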
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
index 7badc26bf0447..4a2aac43b3331 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
@@ -263,7 +263,7 @@ class RocksDB(
logInfo(log"Loading ${MDC(LogKeys.VERSION_NUM, version)}")
try {
if (loadedVersion != version) {
- closeDB()
+ closeDB(ignoreException = false)
// deep copy is needed to avoid race condition
// between maintenance and task threads
fileManager.copyFileMapping()
@@ -407,10 +407,12 @@ class RocksDB(
* Replay change log from the loaded version to the target version.
*/
private def replayChangelog(endVersion: Long): Unit = {
+ logInfo(log"Replaying changelog from version " +
+ log"${MDC(LogKeys.LOADED_VERSION, loadedVersion)} -> " +
+ log"${MDC(LogKeys.END_VERSION, endVersion)}")
for (v <- loadedVersion + 1 to endVersion) {
- logInfo(log"replaying changelog from version " +
- log"${MDC(LogKeys.LOADED_VERSION, loadedVersion)} -> " +
- log"${MDC(LogKeys.END_VERSION, endVersion)}")
+ logInfo(log"Replaying changelog on version " +
+ log"${MDC(LogKeys.VERSION_NUM, v)}")
var changelogReader: StateStoreChangelogReader = null
try {
changelogReader = fileManager.getChangelogReader(v, useColumnFamilies)
@@ -644,15 +646,15 @@ class RocksDB(
// is enabled.
if (shouldForceSnapshot.get()) {
uploadSnapshot()
+ shouldForceSnapshot.set(false)
+ }
+
+ // ensure that changelog files are always written
+ try {
+ assert(changelogWriter.isDefined)
+ changelogWriter.foreach(_.commit())
+ } finally {
changelogWriter = None
- changelogWriter.foreach(_.abort())
- } else {
- try {
- assert(changelogWriter.isDefined)
- changelogWriter.foreach(_.commit())
- } finally {
- changelogWriter = None
- }
}
} else {
assert(changelogWriter.isEmpty)
@@ -927,10 +929,12 @@ class RocksDB(
* @param opType - operation type releasing the lock
*/
private def release(opType: RocksDBOpType): Unit = acquireLock.synchronized {
- logInfo(log"RocksDB instance was released by ${MDC(LogKeys.THREAD, acquiredThreadInfo)} " +
- log"for opType=${MDC(LogKeys.OP_TYPE, opType.toString)}")
- acquiredThreadInfo = null
- acquireLock.notifyAll()
+ if (acquiredThreadInfo != null) {
+ logInfo(log"RocksDB instance was released by ${MDC(LogKeys.THREAD,
+ acquiredThreadInfo)} " + log"for opType=${MDC(LogKeys.OP_TYPE, opType.toString)}")
+ acquiredThreadInfo = null
+ acquireLock.notifyAll()
+ }
}
private def getDBProperty(property: String): Long = db.getProperty(property).toLong
@@ -941,13 +945,17 @@ class RocksDB(
logInfo(log"Opened DB with conf ${MDC(LogKeys.CONFIG, conf)}")
}
- private def closeDB(): Unit = {
+ private def closeDB(ignoreException: Boolean = true): Unit = {
if (db != null) {
-
// Cancel and wait until all background work finishes
db.cancelAllBackgroundWork(true)
- // Close the DB instance
- db.close()
+ if (ignoreException) {
+ // Close the DB instance
+ db.close()
+ } else {
+ // Close the DB instance and throw the exception if any
+ db.closeE()
+ }
db = null
}
}
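
The commit restructuring above guarantees the changelog is committed, and the writer reference cleared, regardless of whether a snapshot upload was forced. A minimal sketch of that control flow with stand-in types (only the ordering and the try/finally are the point):

    trait ChangelogWriter { def commit(): Unit }

    var changelogWriter: Option[ChangelogWriter] = None
    var shouldForceSnapshot: Boolean = false

    def commitChanges(uploadSnapshot: () => Unit): Unit = {
      if (shouldForceSnapshot) {
        uploadSnapshot()
        shouldForceSnapshot = false
      }
      // The changelog is always written, and the writer is cleared even if commit() throws.
      try {
        assert(changelogWriter.isDefined)
        changelogWriter.foreach(_.commit())
      } finally {
        changelogWriter = None
      }
    }
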
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
index 685cc9a1533e4..85f80ce9eb1ae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -567,7 +567,7 @@ private[sql] class RocksDBStateStoreProvider
}
// if the column family is not internal and uses reserved characters, throw an exception
- if (!isInternal && colFamilyName.charAt(0) == '_') {
+ if (!isInternal && colFamilyName.charAt(0) == '$') {
throw StateStoreErrors.cannotCreateColumnFamilyWithReservedChars(colFamilyName)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala
index 47b1cb90e00a8..721d72b6a0991 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala
@@ -27,6 +27,7 @@ import org.apache.spark.internal.{Logging, LogKeys, MDC}
import org.apache.spark.sql.catalyst.util.UnsafeRowUtils
import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, StatefulOperatorStateInfo}
import org.apache.spark.sql.execution.streaming.state.SchemaHelper.{SchemaReader, SchemaWriter}
+import org.apache.spark.sql.execution.streaming.state.StateSchemaCompatibilityChecker.SCHEMA_FORMAT_V3
import org.apache.spark.sql.internal.SessionState
import org.apache.spark.sql.types.{DataType, StructType}
@@ -95,7 +96,7 @@ class StateSchemaCompatibilityChecker(
stateStoreColFamilySchema: List[StateStoreColFamilySchema],
stateSchemaVersion: Int): Unit = {
// Ensure that schema file path is passed explicitly for schema version 3
- if (stateSchemaVersion == 3 && newSchemaFilePath.isEmpty) {
+ if (stateSchemaVersion == SCHEMA_FORMAT_V3 && newSchemaFilePath.isEmpty) {
throw new IllegalStateException("Schema file path is required for schema version 3")
}
@@ -167,12 +168,12 @@ class StateSchemaCompatibilityChecker(
newStateSchema: List[StateStoreColFamilySchema],
ignoreValueSchema: Boolean,
stateSchemaVersion: Int): Boolean = {
- val existingStateSchemaList = getExistingKeyAndValueSchema().sortBy(_.colFamilyName)
- val newStateSchemaList = newStateSchema.sortBy(_.colFamilyName)
+ val existingStateSchemaList = getExistingKeyAndValueSchema()
+ val newStateSchemaList = newStateSchema
if (existingStateSchemaList.isEmpty) {
// write the schema file if it doesn't exist
- createSchemaFile(newStateSchemaList, stateSchemaVersion)
+ createSchemaFile(newStateSchemaList.sortBy(_.colFamilyName), stateSchemaVersion)
true
} else {
// validate if the new schema is compatible with the existing schema
@@ -186,7 +187,13 @@ class StateSchemaCompatibilityChecker(
check(existingStateSchema, newSchema, ignoreValueSchema)
}
}
- false
+ val colFamiliesAddedOrRemoved =
+ (newStateSchemaList.map(_.colFamilyName).toSet != existingSchemaMap.keySet)
+ if (stateSchemaVersion == SCHEMA_FORMAT_V3 && colFamiliesAddedOrRemoved) {
+ createSchemaFile(newStateSchemaList.sortBy(_.colFamilyName), stateSchemaVersion)
+ }
+ // TODO: [SPARK-49535] Write Schema files after schema has changed for StateSchemaV3
+ colFamiliesAddedOrRemoved
}
}
@@ -195,6 +202,9 @@ class StateSchemaCompatibilityChecker(
}
object StateSchemaCompatibilityChecker {
+
+ val SCHEMA_FORMAT_V3: Int = 3
+
private def disallowBinaryInequalityColumn(schema: StructType): Unit = {
if (!UnsafeRowUtils.isBinaryStable(schema)) {
throw new SparkUnsupportedOperationException(
@@ -274,10 +284,31 @@ object StateSchemaCompatibilityChecker {
if (storeConf.stateSchemaCheckEnabled && result.isDefined) {
throw result.get
}
- val schemaFileLocation = newSchemaFilePath match {
- case Some(path) => path.toString
- case None => checker.schemaFileLocation.toString
+ val schemaFileLocation = if (evolvedSchema) {
+      // With state schema v3 and an evolved schema, newSchemaFilePath is expected to be
+      // defined, and the metadata should be populated with this newly written file.
+ if (stateSchemaVersion == SCHEMA_FORMAT_V3) {
+ newSchemaFilePath.get.toString
+ } else {
+        // Versions below v3 always write the schema to a single, fixed location,
+        // which is what we return here.
+ checker.schemaFileLocation.toString
+ }
+ } else {
+      // If the schema has not evolved (a previous schema already exists) and we are
+      // using state schema v3, oldSchemaFilePath is expected to be defined, so the
+      // next run's metadata file is populated with that path.
+ if (stateSchemaVersion == SCHEMA_FORMAT_V3) {
+ oldSchemaFilePath.get.toString
+ } else {
+        // Versions below v3 always write the schema to a single, fixed location,
+        // which is what we return here.
+ checker.schemaFileLocation.toString
+ }
}
+
StateSchemaValidationResult(evolvedSchema, schemaFileLocation)
}
}
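
The branching above reduces to a small decision: v3 records a per-change schema file (the new one on evolution, otherwise the previous one), while older versions always use the single fixed location. A compact sketch of that decision with strings in place of Hadoop paths (parameter names mirror the diff):

    val SCHEMA_FORMAT_V3 = 3

    def resolveSchemaFileLocation(
        evolvedSchema: Boolean,
        stateSchemaVersion: Int,
        newSchemaFilePath: Option[String],
        oldSchemaFilePath: Option[String],
        fixedSchemaFileLocation: String): String = {
      if (stateSchemaVersion == SCHEMA_FORMAT_V3) {
        // v3 writes one schema file per change, so point at the freshly written file when
        // the schema evolved, otherwise reuse the previous run's file.
        if (evolvedSchema) newSchemaFilePath.get else oldSchemaFilePath.get
      } else {
        // Versions below v3 always write to a single, fixed location.
        fixedSchemaFileLocation
      }
    }
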
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index c092adb354c2d..3cb41710a22c8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -166,17 +166,19 @@ trait StateStoreWriter extends StatefulOperator with PythonSQLMetrics { self: Sp
"number of state store instances")
) ++ stateStoreCustomMetrics ++ pythonMetrics
- def stateSchemaFilePath(storeName: Option[String] = None): Path = {
- def stateInfo = getStateInfo
+  // This method is only used to fetch the state schema directory path for
+  // operators that use StateSchemaV3, as prior versions only use a single,
+  // fixed file path.
+  def stateSchemaDirPath(storeName: Option[String] = None): Path = {
val stateCheckpointPath =
new Path(getStateInfo.checkpointLocation,
- s"${stateInfo.operatorId.toString}")
+ s"${getStateInfo.operatorId.toString}")
storeName match {
case Some(storeName) =>
- val storeNamePath = new Path(stateCheckpointPath, storeName)
- new Path(new Path(storeNamePath, "_metadata"), "schema")
+ new Path(new Path(stateCheckpointPath, "_stateSchema"), storeName)
case None =>
- new Path(new Path(stateCheckpointPath, "_metadata"), "schema")
+ new Path(new Path(stateCheckpointPath, "_stateSchema"), "default")
}
}
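
With the rename above, StateSchemaV3 operators keep their schema files under a `_stateSchema` directory per operator, keyed by store name with a `default` fallback. A small sketch of the resulting layout (the checkpoint root below is a placeholder):

    import org.apache.hadoop.fs.Path

    def stateSchemaDir(stateCheckpointRoot: String, operatorId: Long, storeName: Option[String]): Path = {
      val operatorPath = new Path(stateCheckpointRoot, operatorId.toString)
      new Path(new Path(operatorPath, "_stateSchema"), storeName.getOrElse("default"))
    }

    // stateSchemaDir("/tmp/query/state", 0L, None)
    //   -> /tmp/query/state/0/_stateSchema/default
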
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
index fd3df372a2d56..192b5bf65c4c5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
@@ -19,7 +19,8 @@ package org.apache.spark.sql.expressions
import org.apache.spark.SparkException
import org.apache.spark.sql.Encoder
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveBooleanEncoder, ProductEncoder}
/**
* An aggregator that uses a single associative and commutative reduce function. This reduce
@@ -46,10 +47,10 @@ private[sql] class ReduceAggregator[T: Encoder](func: (T, T) => T)
override def zero: (Boolean, T) = (false, _zero.asInstanceOf[T])
- override def bufferEncoder: Encoder[(Boolean, T)] =
- ExpressionEncoder.tuple(
- ExpressionEncoder[Boolean](),
- encoder.asInstanceOf[ExpressionEncoder[T]])
+ override def bufferEncoder: Encoder[(Boolean, T)] = {
+ ProductEncoder.tuple(Seq(PrimitiveBooleanEncoder, encoder.asInstanceOf[AgnosticEncoder[T]]))
+ .asInstanceOf[Encoder[(Boolean, T)]]
+ }
override def outputEncoder: Encoder[T] = encoder
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 99287bddb5104..0d0258f11efb1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.internal
import org.apache.spark.annotation.Unstable
import org.apache.spark.sql.{ExperimentalMethods, SparkSession, UDFRegistration, _}
import org.apache.spark.sql.artifact.ArtifactManager
-import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, FunctionRegistry, ReplaceCharWithVarchar, ResolveSessionCatalog, TableFunctionRegistry}
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, FunctionRegistry, InvokeProcedures, ReplaceCharWithVarchar, ResolveSessionCatalog, ResolveTranspose, TableFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{FunctionExpressionBuilder, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.optimizer.Optimizer
@@ -205,6 +205,8 @@ abstract class BaseSessionStateBuilder(
new ResolveSessionCatalog(this.catalogManager) +:
ResolveWriteToStream +:
new EvalSubqueriesForTimeTravel +:
+ new ResolveTranspose(session) +:
+ new InvokeProcedures(session) +:
customResolutionRules
override val postHocResolutionRules: Seq[Rule[LogicalPlan]] =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
index 8b4fa90b3119d..52b8d35e2fbf8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.plans.logical.{ColumnDefinition, CreateTable, LocalRelation, LogicalPlan, OptionList, RecoverPartitions, ShowFunctions, ShowNamespaces, ShowTables, UnresolvedTableSpec, View}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.connector.catalog.{CatalogManager, SupportsNamespaces, TableCatalog}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.{CatalogHelper, MultipartIdentifierHelper, NamespaceHelper, TransformHelper}
import org.apache.spark.sql.errors.QueryCompilationErrors
@@ -671,12 +672,9 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
} else {
CatalogTableType.MANAGED
}
- val location = if (storage.locationUri.isDefined) {
- val locationStr = storage.locationUri.get.toString
- Some(locationStr)
- } else {
- None
- }
+
+ // The location in UnresolvedTableSpec should be the original user-provided path string.
+ val location = CaseInsensitiveMap(options).get("path")
val newOptions = OptionList(options.map { case (key, value) =>
(key, Literal(value).asInstanceOf[Expression])
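
The location is now taken from the user-supplied `path` option with a case-insensitive lookup, instead of the resolved storage URI. A small illustration of the lookup behaviour (CaseInsensitiveMap is Spark's internal catalyst utility; the option map is made up):

    import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

    val options = Map("PATH" -> "s3://bucket/tables/t1", "format" -> "parquet")

    // "path", "PATH" and "Path" all resolve to the same entry.
    val location: Option[String] = CaseInsensitiveMap(options).get("path")
    // location == Some("s3://bucket/tables/t1")
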
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterImpl.scala
index 7248a2d3f056e..f0eef9ae1cbb0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterImpl.scala
@@ -426,8 +426,10 @@ final class DataFrameWriterImpl[T] private[sql](ds: Dataset[T]) extends DataFram
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
val session = df.sparkSession
- val canUseV2 = lookupV2Provider().isDefined ||
- df.sparkSession.sessionState.conf.getConf(SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION).isDefined
+ val canUseV2 = lookupV2Provider().isDefined || (df.sparkSession.sessionState.conf.getConf(
+ SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION).isDefined &&
+ !df.sparkSession.sessionState.catalogManager.catalog(CatalogManager.SESSION_CATALOG_NAME)
+ .isInstanceOf[DelegatingCatalogExtension])
session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match {
case nameParts @ NonSessionCatalogAndIdentifier(catalog, ident) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterV2Impl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterV2Impl.scala
new file mode 100644
index 0000000000000..0a19e6c47afa9
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterV2Impl.scala
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.internal
+
+import java.util
+
+import scala.collection.mutable
+import scala.jdk.CollectionConverters.MapHasAsScala
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.{Column, DataFrame, DataFrameWriterV2, Dataset}
+import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, UnresolvedFunction, UnresolvedIdentifier, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Literal}
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.connector.catalog.TableWritePrivilege._
+import org.apache.spark.sql.connector.expressions._
+import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.types.IntegerType
+
+/**
+ * Interface used to write a [[org.apache.spark.sql.Dataset]] to external storage using the v2 API.
+ *
+ * @since 3.0.0
+ */
+@Experimental
+final class DataFrameWriterV2Impl[T] private[sql](table: String, ds: Dataset[T])
+ extends DataFrameWriterV2[T] {
+
+ private val df: DataFrame = ds.toDF()
+
+ private val sparkSession = ds.sparkSession
+ import sparkSession.expression
+
+ private val tableName = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(table)
+
+ private val logicalPlan = df.queryExecution.logical
+
+ private var provider: Option[String] = None
+
+ private val options = new mutable.HashMap[String, String]()
+
+ private val properties = new mutable.HashMap[String, String]()
+
+ private var partitioning: Option[Seq[Transform]] = None
+
+ private var clustering: Option[ClusterByTransform] = None
+
+ /** @inheritdoc */
+ override def using(provider: String): this.type = {
+ this.provider = Some(provider)
+ this
+ }
+
+ /** @inheritdoc */
+ override def option(key: String, value: String): this.type = {
+ this.options.put(key, value)
+ this
+ }
+
+ /** @inheritdoc */
+ override def options(options: scala.collection.Map[String, String]): this.type = {
+ options.foreach {
+ case (key, value) =>
+ this.options.put(key, value)
+ }
+ this
+ }
+
+ /** @inheritdoc */
+ override def options(options: util.Map[String, String]): this.type = {
+ this.options(options.asScala)
+ this
+ }
+
+ /** @inheritdoc */
+ override def tableProperty(property: String, value: String): this.type = {
+ this.properties.put(property, value)
+ this
+ }
+
+
+ /** @inheritdoc */
+ @scala.annotation.varargs
+ override def partitionedBy(column: Column, columns: Column*): this.type = {
+ def ref(name: String): NamedReference = LogicalExpressions.parseReference(name)
+
+ val asTransforms = (column +: columns).map(expression).map {
+ case PartitionTransform.YEARS(Seq(attr: Attribute)) =>
+ LogicalExpressions.years(ref(attr.name))
+ case PartitionTransform.MONTHS(Seq(attr: Attribute)) =>
+ LogicalExpressions.months(ref(attr.name))
+ case PartitionTransform.DAYS(Seq(attr: Attribute)) =>
+ LogicalExpressions.days(ref(attr.name))
+ case PartitionTransform.HOURS(Seq(attr: Attribute)) =>
+ LogicalExpressions.hours(ref(attr.name))
+ case PartitionTransform.BUCKET(Seq(Literal(numBuckets: Int, IntegerType), attr: Attribute)) =>
+ LogicalExpressions.bucket(numBuckets, Array(ref(attr.name)))
+ case PartitionTransform.BUCKET(Seq(numBuckets, e)) =>
+ throw QueryCompilationErrors.invalidBucketsNumberError(numBuckets.toString, e.toString)
+ case attr: Attribute =>
+ LogicalExpressions.identity(ref(attr.name))
+ case expr =>
+ throw QueryCompilationErrors.invalidPartitionTransformationError(expr)
+ }
+
+ this.partitioning = Some(asTransforms)
+ validatePartitioning()
+ this
+ }
+
+ /** @inheritdoc */
+ @scala.annotation.varargs
+ override def clusterBy(colName: String, colNames: String*): this.type = {
+ this.clustering =
+ Some(ClusterByTransform((colName +: colNames).map(col => FieldReference(col))))
+ validatePartitioning()
+ this
+ }
+
+ /**
+ * Validate that clusterBy is not used with partitionBy.
+ */
+ private def validatePartitioning(): Unit = {
+ if (partitioning.nonEmpty && clustering.nonEmpty) {
+ throw QueryCompilationErrors.clusterByWithPartitionedBy()
+ }
+ }
+
+ /** @inheritdoc */
+ override def create(): Unit = {
+ val tableSpec = UnresolvedTableSpec(
+ properties = properties.toMap,
+ provider = provider,
+ optionExpression = OptionList(Seq.empty),
+ location = None,
+ comment = None,
+ serde = None,
+ external = false)
+ runCommand(
+ CreateTableAsSelect(
+ UnresolvedIdentifier(tableName),
+ partitioning.getOrElse(Seq.empty) ++ clustering,
+ logicalPlan,
+ tableSpec,
+ options.toMap,
+ false))
+ }
+
+ /** @inheritdoc */
+ override def replace(): Unit = {
+ internalReplace(orCreate = false)
+ }
+
+ /** @inheritdoc */
+ override def createOrReplace(): Unit = {
+ internalReplace(orCreate = true)
+ }
+
+ /** @inheritdoc */
+ @throws(classOf[NoSuchTableException])
+ def append(): Unit = {
+ val append = AppendData.byName(
+ UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT)),
+ logicalPlan, options.toMap)
+ runCommand(append)
+ }
+
+ /** @inheritdoc */
+ @throws(classOf[NoSuchTableException])
+ def overwrite(condition: Column): Unit = {
+ val overwrite = OverwriteByExpression.byName(
+ UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT, DELETE)),
+ logicalPlan, expression(condition), options.toMap)
+ runCommand(overwrite)
+ }
+
+ /** @inheritdoc */
+ @throws(classOf[NoSuchTableException])
+ def overwritePartitions(): Unit = {
+ val dynamicOverwrite = OverwritePartitionsDynamic.byName(
+ UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT, DELETE)),
+ logicalPlan, options.toMap)
+ runCommand(dynamicOverwrite)
+ }
+
+ /**
+ * Wrap an action to track the QueryExecution and time cost, then report to the user-registered
+ * callback functions.
+ */
+ private def runCommand(command: LogicalPlan): Unit = {
+ val qe = new QueryExecution(sparkSession, command, df.queryExecution.tracker)
+ qe.assertCommandExecuted()
+ }
+
+ private def internalReplace(orCreate: Boolean): Unit = {
+ val tableSpec = UnresolvedTableSpec(
+ properties = properties.toMap,
+ provider = provider,
+ optionExpression = OptionList(Seq.empty),
+ location = None,
+ comment = None,
+ serde = None,
+ external = false)
+ runCommand(ReplaceTableAsSelect(
+ UnresolvedIdentifier(tableName),
+ partitioning.getOrElse(Seq.empty) ++ clustering,
+ logicalPlan,
+ tableSpec,
+ writeOptions = options.toMap,
+ orCreate = orCreate))
+ }
+}
+
+private object PartitionTransform {
+ class ExtractTransform(name: String) {
+ private val NAMES = Seq(name)
+
+ def unapply(e: Expression): Option[Seq[Expression]] = e match {
+ case UnresolvedFunction(NAMES, children, false, None, false, Nil, true) => Option(children)
+ case _ => None
+ }
+ }
+
+ val HOURS = new ExtractTransform("hours")
+ val DAYS = new ExtractTransform("days")
+ val MONTHS = new ExtractTransform("months")
+ val YEARS = new ExtractTransform("years")
+ val BUCKET = new ExtractTransform("bucket")
+}
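
For orientation, DataFrameWriterV2Impl backs the existing Dataset.writeTo builder; a representative call chain looks like the sketch below (catalog, table and column names are placeholders, and `spark`/`df` are assumed to exist):

    import org.apache.spark.sql.functions.{col, days}

    df.writeTo("my_catalog.db.events")
      .using("parquet")
      .tableProperty("owner", "data-eng")
      .partitionedBy(days(col("ts")))
      .create()

    // Subsequent writes reuse the same builder, e.g.
    // df.writeTo("my_catalog.db.events").append()
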
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/MergeIntoWriterImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/MergeIntoWriterImpl.scala
new file mode 100644
index 0000000000000..bb8146e3e0e33
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/MergeIntoWriterImpl.scala
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.internal
+
+import scala.collection.mutable
+
+import org.apache.spark.SparkRuntimeException
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.{Column, DataFrame, Dataset, MergeIntoWriter}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.functions.expr
+
+/**
+ * `MergeIntoWriter` provides methods to define and execute merge actions based
+ * on specified conditions.
+ *
+ * @tparam T the type of data in the Dataset.
+ * @param table the name of the target table for the merge operation.
+ * @param ds the source Dataset to merge into the target table.
+ * @param on the merge condition.
+ *
+ * @since 4.0.0
+ */
+@Experimental
+class MergeIntoWriterImpl[T] private[sql] (table: String, ds: Dataset[T], on: Column)
+ extends MergeIntoWriter[T] {
+
+ private val df: DataFrame = ds.toDF()
+
+ private[sql] val sparkSession = ds.sparkSession
+ import sparkSession.RichColumn
+
+ private val tableName = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(table)
+
+ private val logicalPlan = df.queryExecution.logical
+
+ private[sql] val matchedActions = mutable.Buffer.empty[MergeAction]
+ private[sql] val notMatchedActions = mutable.Buffer.empty[MergeAction]
+ private[sql] val notMatchedBySourceActions = mutable.Buffer.empty[MergeAction]
+
+ /** @inheritdoc */
+ def merge(): Unit = {
+ if (matchedActions.isEmpty && notMatchedActions.isEmpty && notMatchedBySourceActions.isEmpty) {
+ throw new SparkRuntimeException(
+ errorClass = "NO_MERGE_ACTION_SPECIFIED",
+ messageParameters = Map.empty)
+ }
+
+ val merge = MergeIntoTable(
+ UnresolvedRelation(tableName).requireWritePrivileges(MergeIntoTable.getWritePrivileges(
+ matchedActions, notMatchedActions, notMatchedBySourceActions)),
+ logicalPlan,
+ on.expr,
+ matchedActions.toSeq,
+ notMatchedActions.toSeq,
+ notMatchedBySourceActions.toSeq,
+ schemaEvolutionEnabled)
+ val qe = sparkSession.sessionState.executePlan(merge)
+ qe.assertCommandExecuted()
+ }
+
+ override protected[sql] def insertAll(condition: Option[Column]): MergeIntoWriter[T] = {
+ this.notMatchedActions += InsertStarAction(condition.map(_.expr))
+ this
+ }
+
+ override protected[sql] def insert(
+ condition: Option[Column],
+ map: Map[String, Column]): MergeIntoWriter[T] = {
+ this.notMatchedActions += InsertAction(condition.map(_.expr), mapToAssignments(map))
+ this
+ }
+
+ override protected[sql] def updateAll(
+ condition: Option[Column],
+ notMatchedBySource: Boolean): MergeIntoWriter[T] = {
+ appendUpdateDeleteAction(UpdateStarAction(condition.map(_.expr)), notMatchedBySource)
+ }
+
+ override protected[sql] def update(
+ condition: Option[Column],
+ map: Map[String, Column],
+ notMatchedBySource: Boolean): MergeIntoWriter[T] = {
+ appendUpdateDeleteAction(
+ UpdateAction(condition.map(_.expr), mapToAssignments(map)),
+ notMatchedBySource)
+ }
+
+ override protected[sql] def delete(
+ condition: Option[Column],
+ notMatchedBySource: Boolean): MergeIntoWriter[T] = {
+ appendUpdateDeleteAction(DeleteAction(condition.map(_.expr)), notMatchedBySource)
+ }
+
+ private def appendUpdateDeleteAction(
+ action: MergeAction,
+ notMatchedBySource: Boolean): MergeIntoWriter[T] = {
+ if (notMatchedBySource) {
+ notMatchedBySourceActions += action
+ } else {
+ matchedActions += action
+ }
+ this
+ }
+
+ private def mapToAssignments(map: Map[String, Column]): Seq[Assignment] = {
+ map.map(x => Assignment(expr(x._1).expr, x._2.expr)).toSeq
+ }
+}
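
Similarly, MergeIntoWriterImpl backs Dataset.mergeInto; a typical usage is sketched below (the table name, column names, `source` alias and `updates` DataFrame are all placeholders):

    import org.apache.spark.sql.functions.col

    updates.alias("source")
      .mergeInto("target", col("target.id") === col("source.id"))
      .whenMatched()
      .updateAll()
      .whenNotMatched()
      .insertAll()
      .merge()
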
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala
similarity index 51%
rename from sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala
index ed8cf4f121f03..ca439cdb89958 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala
@@ -15,15 +15,15 @@
* limitations under the License.
*/
-package org.apache.spark.sql
+package org.apache.spark.sql.internal
import scala.jdk.CollectionConverters._
import org.apache.spark.SPARK_DOC_ROOT
import org.apache.spark.annotation.Stable
-import org.apache.spark.internal.config.{ConfigEntry, OptionalConfigEntry}
+import org.apache.spark.internal.config.ConfigEntry
+import org.apache.spark.sql.RuntimeConfig
import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.internal.SQLConf
/**
* Runtime configuration interface for Spark. To access this, use `SparkSession.conf`.
@@ -33,89 +33,26 @@ import org.apache.spark.sql.internal.SQLConf
* @since 2.0.0
*/
@Stable
-class RuntimeConfig private[sql](val sqlConf: SQLConf = new SQLConf) {
+class RuntimeConfigImpl private[sql](val sqlConf: SQLConf = new SQLConf) extends RuntimeConfig {
- /**
- * Sets the given Spark runtime configuration property.
- *
- * @since 2.0.0
- */
+ /** @inheritdoc */
def set(key: String, value: String): Unit = {
requireNonStaticConf(key)
sqlConf.setConfString(key, value)
}
- /**
- * Sets the given Spark runtime configuration property.
- *
- * @since 2.0.0
- */
- def set(key: String, value: Boolean): Unit = {
- set(key, value.toString)
- }
-
- /**
- * Sets the given Spark runtime configuration property.
- *
- * @since 2.0.0
- */
- def set(key: String, value: Long): Unit = {
- set(key, value.toString)
- }
-
- /**
- * Sets the given Spark runtime configuration property.
- */
- private[sql] def set[T](entry: ConfigEntry[T], value: T): Unit = {
- requireNonStaticConf(entry.key)
- sqlConf.setConf(entry, value)
- }
-
- /**
- * Returns the value of Spark runtime configuration property for the given key.
- *
- * @throws java.util.NoSuchElementException if the key is not set and does not have a default
- * value
- * @since 2.0.0
- */
+ /** @inheritdoc */
@throws[NoSuchElementException]("if the key is not set")
def get(key: String): String = {
sqlConf.getConfString(key)
}
- /**
- * Returns the value of Spark runtime configuration property for the given key.
- *
- * @since 2.0.0
- */
+ /** @inheritdoc */
def get(key: String, default: String): String = {
sqlConf.getConfString(key, default)
}
- /**
- * Returns the value of Spark runtime configuration property for the given key.
- */
- @throws[NoSuchElementException]("if the key is not set")
- private[sql] def get[T](entry: ConfigEntry[T]): T = {
- sqlConf.getConf(entry)
- }
-
- private[sql] def get[T](entry: OptionalConfigEntry[T]): Option[T] = {
- sqlConf.getConf(entry)
- }
-
- /**
- * Returns the value of Spark runtime configuration property for the given key.
- */
- private[sql] def get[T](entry: ConfigEntry[T], default: T): T = {
- sqlConf.getConf(entry, default)
- }
-
- /**
- * Returns all properties set in this conf.
- *
- * @since 2.0.0
- */
+ /** @inheritdoc */
def getAll: Map[String, String] = {
sqlConf.getAllConfs
}
@@ -124,36 +61,20 @@ class RuntimeConfig private[sql](val sqlConf: SQLConf = new SQLConf) {
getAll.asJava
}
- /**
- * Returns the value of Spark runtime configuration property for the given key.
- *
- * @since 2.0.0
- */
+ /** @inheritdoc */
def getOption(key: String): Option[String] = {
try Option(get(key)) catch {
case _: NoSuchElementException => None
}
}
- /**
- * Resets the configuration property for the given key.
- *
- * @since 2.0.0
- */
+ /** @inheritdoc */
def unset(key: String): Unit = {
requireNonStaticConf(key)
sqlConf.unsetConf(key)
}
- /**
- * Indicates whether the configuration property with the given key
- * is modifiable in the current session.
- *
- * @return `true` if the configuration property is modifiable. For static SQL, Spark Core,
- * invalid (not existing) and other non-modifiable configuration properties,
- * the returned value is `false`.
- * @since 2.4.0
- */
+ /** @inheritdoc */
def isModifiable(key: String): Boolean = sqlConf.isModifiable(key)
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/TypedAggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/TypedAggUtils.scala
index b6340a35e7703..23ceb8135fa8a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/TypedAggUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/TypedAggUtils.scala
@@ -17,7 +17,9 @@
package org.apache.spark.sql.internal
+import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.agnosticEncoderFor
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
@@ -25,10 +27,10 @@ import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
private[sql] object TypedAggUtils {
def aggKeyColumn[A](
- encoder: ExpressionEncoder[A],
+ encoder: Encoder[A],
groupingAttributes: Seq[Attribute]): NamedExpression = {
- if (!encoder.isSerializedAsStructForTopLevel) {
- assert(groupingAttributes.length == 1)
+ val agnosticEncoder = agnosticEncoderFor(encoder)
+ if (!agnosticEncoder.isStruct) {
if (SQLConf.get.nameNonStructGroupingKeyAsValue) {
groupingAttributes.head
} else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
index f2b626490d13c..785bf5b13aa78 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
@@ -46,12 +46,33 @@ private case class MySQLDialect() extends JdbcDialect with SQLConfHelper with No
// See https://dev.mysql.com/doc/refman/8.0/en/aggregate-functions.html
private val supportedAggregateFunctions =
Set("MAX", "MIN", "SUM", "COUNT", "AVG") ++ distinctUnsupportedAggregateFunctions
- private val supportedFunctions = supportedAggregateFunctions
+ private val supportedFunctions = supportedAggregateFunctions ++ Set("DATE_ADD", "DATE_DIFF")
override def isSupportedFunction(funcName: String): Boolean =
supportedFunctions.contains(funcName)
class MySQLSQLBuilder extends JDBCSQLBuilder {
+ override def visitExtract(field: String, source: String): String = {
+ field match {
+ case "DAY_OF_YEAR" => s"DAYOFYEAR($source)"
+ case "YEAR_OF_WEEK" => s"EXTRACT(YEAR FROM $source)"
+ // WEEKDAY uses Monday = 0, Tuesday = 1, ... and ISO standard is Monday = 1, ...,
+ // so we use the formula (WEEKDAY + 1) to follow the ISO standard.
+ case "DAY_OF_WEEK" => s"(WEEKDAY($source) + 1)"
+ case _ => super.visitExtract(field, source)
+ }
+ }
+
+ override def visitSQLFunction(funcName: String, inputs: Array[String]): String = {
+ funcName match {
+ case "DATE_ADD" =>
+ s"DATE_ADD(${inputs(0)}, INTERVAL ${inputs(1)} DAY)"
+ case "DATE_DIFF" =>
+ s"DATEDIFF(${inputs(0)}, ${inputs(1)})"
+ case _ => super.visitSQLFunction(funcName, inputs)
+ }
+ }
+
override def visitSortOrder(
sortKey: String, sortDirection: SortDirection, nullOrdering: NullOrdering): String = {
(sortDirection, nullOrdering) match {
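
The additions above extend pushdown to a few EXTRACT fields plus DATE_ADD/DATE_DIFF. A hedged sketch of those translation rules in isolation, as plain string rewriting detached from the JDBCSQLBuilder machinery (inputs are assumed to be already-rendered SQL fragments):

    // Mirrors the EXTRACT mappings added to MySQLSQLBuilder.
    def mysqlExtract(field: String, source: String): String = field match {
      case "DAY_OF_YEAR"  => s"DAYOFYEAR($source)"
      case "YEAR_OF_WEEK" => s"EXTRACT(YEAR FROM $source)"
      // MySQL WEEKDAY is 0-based from Monday; ISO DAY_OF_WEEK is 1-based from Monday.
      case "DAY_OF_WEEK"  => s"(WEEKDAY($source) + 1)"
      case other          => s"EXTRACT($other FROM $source)"
    }

    // Mirrors the function mappings; anything else falls through to the generic form.
    def mysqlFunction(funcName: String, inputs: Array[String]): String = funcName match {
      case "DATE_ADD"  => s"DATE_ADD(${inputs(0)}, INTERVAL ${inputs(1)} DAY)"
      case "DATE_DIFF" => s"DATEDIFF(${inputs(0)}, ${inputs(1)})"
      case other       => s"$other(${inputs.mkString(", ")})"
    }
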
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
index 7085366c3b7a3..af9fd5464277c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
@@ -140,13 +140,20 @@ class SingleStatementExec(
* Implements recursive iterator logic over all child execution nodes.
* @param collection
* Collection of child execution nodes.
+ * @param label
+ * Label set by user or None otherwise.
*/
-abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundStatementExec])
+abstract class CompoundNestedStatementIteratorExec(
+ collection: Seq[CompoundStatementExec],
+ label: Option[String] = None)
extends NonLeafStatementExec {
private var localIterator = collection.iterator
private var curr = if (localIterator.hasNext) Some(localIterator.next()) else None
+  /** Used to stop the iteration when a LEAVE statement is encountered. */
+ private var stopIteration = false
+
private lazy val treeIterator: Iterator[CompoundStatementExec] =
new Iterator[CompoundStatementExec] {
override def hasNext: Boolean = {
@@ -157,7 +164,7 @@ abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundState
case _ => throw SparkException.internalError(
"Unknown statement type encountered during SQL script interpretation.")
}
- localIterator.hasNext || childHasNext
+ !stopIteration && (localIterator.hasNext || childHasNext)
}
@scala.annotation.tailrec
@@ -165,12 +172,28 @@ abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundState
curr match {
case None => throw SparkException.internalError(
"No more elements to iterate through in the current SQL compound statement.")
+ case Some(leaveStatement: LeaveStatementExec) =>
+ handleLeaveStatement(leaveStatement)
+ curr = None
+ leaveStatement
+ case Some(iterateStatement: IterateStatementExec) =>
+ handleIterateStatement(iterateStatement)
+ curr = None
+ iterateStatement
case Some(statement: LeafStatementExec) =>
curr = if (localIterator.hasNext) Some(localIterator.next()) else None
statement
case Some(body: NonLeafStatementExec) =>
if (body.getTreeIterator.hasNext) {
- body.getTreeIterator.next()
+ body.getTreeIterator.next() match {
+ case leaveStatement: LeaveStatementExec =>
+ handleLeaveStatement(leaveStatement)
+ leaveStatement
+ case iterateStatement: IterateStatementExec =>
+ handleIterateStatement(iterateStatement)
+ iterateStatement
+ case other => other
+ }
} else {
curr = if (localIterator.hasNext) Some(localIterator.next()) else None
next()
@@ -187,6 +210,37 @@ abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundState
collection.foreach(_.reset())
localIterator = collection.iterator
curr = if (localIterator.hasNext) Some(localIterator.next()) else None
+ stopIteration = false
+ }
+
+  /** Actions to take when a LEAVE statement is encountered, to stop execution of this compound. */
+ private def handleLeaveStatement(leaveStatement: LeaveStatementExec): Unit = {
+ if (!leaveStatement.hasBeenMatched) {
+ // Stop the iteration.
+ stopIteration = true
+
+ // TODO: Variable cleanup (once we add SQL script execution logic).
+ // TODO: Add interpreter tests as well.
+
+ // Check if label has been matched.
+ leaveStatement.hasBeenMatched = label.isDefined && label.get.equals(leaveStatement.label)
+ }
+ }
+
+ /**
+   * Actions to take when an ITERATE statement is encountered, to stop execution of this compound.
+ */
+ private def handleIterateStatement(iterateStatement: IterateStatementExec): Unit = {
+ if (!iterateStatement.hasBeenMatched) {
+ // Stop the iteration.
+ stopIteration = true
+
+ // TODO: Variable cleanup (once we add SQL script execution logic).
+ // TODO: Add interpreter tests as well.
+
+ // No need to check if label has been matched, since ITERATE statement is already
+ // not allowed in CompoundBody.
+ }
}
}
@@ -194,9 +248,11 @@ abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundState
* Executable node for CompoundBody.
* @param statements
* Executable nodes for nested statements within the CompoundBody.
+ * @param label
+ * Label set by user to CompoundBody or None otherwise.
*/
-class CompoundBodyExec(statements: Seq[CompoundStatementExec])
- extends CompoundNestedStatementIteratorExec(statements)
+class CompoundBodyExec(statements: Seq[CompoundStatementExec], label: Option[String] = None)
+ extends CompoundNestedStatementIteratorExec(statements, label)
/**
* Executable node for IfElseStatement.
@@ -277,11 +333,13 @@ class IfElseStatementExec(
* Executable node for WhileStatement.
* @param condition Executable node for the condition.
* @param body Executable node for the body.
+ * @param label Label set to WhileStatement by user or None otherwise.
* @param session Spark session that SQL script is executed within.
*/
class WhileStatementExec(
condition: SingleStatementExec,
body: CompoundBodyExec,
+ label: Option[String],
session: SparkSession) extends NonLeafStatementExec {
private object WhileState extends Enumeration {
@@ -308,6 +366,26 @@ class WhileStatementExec(
condition
case WhileState.Body =>
val retStmt = body.getTreeIterator.next()
+
+ // Handle LEAVE or ITERATE statement if it has been encountered.
+ retStmt match {
+ case leaveStatementExec: LeaveStatementExec if !leaveStatementExec.hasBeenMatched =>
+ if (label.contains(leaveStatementExec.label)) {
+ leaveStatementExec.hasBeenMatched = true
+ }
+ curr = None
+ return retStmt
+ case iterStatementExec: IterateStatementExec if !iterStatementExec.hasBeenMatched =>
+ if (label.contains(iterStatementExec.label)) {
+ iterStatementExec.hasBeenMatched = true
+ }
+ state = WhileState.Condition
+ curr = Some(condition)
+ condition.reset()
+ return retStmt
+ case _ =>
+ }
+
if (!body.getTreeIterator.hasNext) {
state = WhileState.Condition
curr = Some(condition)
@@ -326,3 +404,191 @@ class WhileStatementExec(
body.reset()
}
}
+
+/**
+ * Executable node for CaseStatement.
+ * @param conditions Collection of executable conditions which correspond to WHEN clauses.
+ * @param conditionalBodies Collection of executable bodies that have a corresponding condition,
+ * in WHEN branches.
+ * @param elseBody Body that is executed if none of the conditions are met, i.e. ELSE branch.
+ * @param session Spark session that SQL script is executed within.
+ */
+class CaseStatementExec(
+ conditions: Seq[SingleStatementExec],
+ conditionalBodies: Seq[CompoundBodyExec],
+ elseBody: Option[CompoundBodyExec],
+ session: SparkSession) extends NonLeafStatementExec {
+ private object CaseState extends Enumeration {
+ val Condition, Body = Value
+ }
+
+ private var state = CaseState.Condition
+ private var curr: Option[CompoundStatementExec] = Some(conditions.head)
+
+ private var clauseIdx: Int = 0
+ private val conditionsCount = conditions.length
+
+ private lazy val treeIterator: Iterator[CompoundStatementExec] =
+ new Iterator[CompoundStatementExec] {
+ override def hasNext: Boolean = curr.nonEmpty
+
+ override def next(): CompoundStatementExec = state match {
+ case CaseState.Condition =>
+ val condition = curr.get.asInstanceOf[SingleStatementExec]
+ if (evaluateBooleanCondition(session, condition)) {
+ state = CaseState.Body
+ curr = Some(conditionalBodies(clauseIdx))
+ } else {
+ clauseIdx += 1
+ if (clauseIdx < conditionsCount) {
+ // There are WHEN clauses remaining.
+ state = CaseState.Condition
+ curr = Some(conditions(clauseIdx))
+ } else if (elseBody.isDefined) {
+ // ELSE clause exists.
+ state = CaseState.Body
+ curr = Some(elseBody.get)
+ } else {
+ // No remaining clauses.
+ curr = None
+ }
+ }
+ condition
+ case CaseState.Body =>
+ assert(curr.get.isInstanceOf[CompoundBodyExec])
+ val currBody = curr.get.asInstanceOf[CompoundBodyExec]
+ val retStmt = currBody.getTreeIterator.next()
+ if (!currBody.getTreeIterator.hasNext) {
+ curr = None
+ }
+ retStmt
+ }
+ }
+
+ override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator
+
+ override def reset(): Unit = {
+ state = CaseState.Condition
+ curr = Some(conditions.head)
+ clauseIdx = 0
+ conditions.foreach(c => c.reset())
+ conditionalBodies.foreach(b => b.reset())
+ elseBody.foreach(b => b.reset())
+ }
+}
+
+/**
+ * Executable node for RepeatStatement.
+ * @param condition Executable node for the condition; it must evaluate to a row with a single
+ *                  boolean expression, otherwise an exception is thrown.
+ * @param body Executable node for the body.
+ * @param label Label set to RepeatStatement by user, None if not set
+ * @param session Spark session that SQL script is executed within.
+ */
+class RepeatStatementExec(
+ condition: SingleStatementExec,
+ body: CompoundBodyExec,
+ label: Option[String],
+ session: SparkSession) extends NonLeafStatementExec {
+
+ private object RepeatState extends Enumeration {
+ val Condition, Body = Value
+ }
+
+ private var state = RepeatState.Body
+ private var curr: Option[CompoundStatementExec] = Some(body)
+
+ private lazy val treeIterator: Iterator[CompoundStatementExec] =
+ new Iterator[CompoundStatementExec] {
+ override def hasNext: Boolean = curr.nonEmpty
+
+ override def next(): CompoundStatementExec = state match {
+ case RepeatState.Condition =>
+ val condition = curr.get.asInstanceOf[SingleStatementExec]
+ if (!evaluateBooleanCondition(session, condition)) {
+ state = RepeatState.Body
+ curr = Some(body)
+ body.reset()
+ } else {
+ curr = None
+ }
+ condition
+ case RepeatState.Body =>
+ val retStmt = body.getTreeIterator.next()
+
+ retStmt match {
+ case leaveStatementExec: LeaveStatementExec if !leaveStatementExec.hasBeenMatched =>
+ if (label.contains(leaveStatementExec.label)) {
+ leaveStatementExec.hasBeenMatched = true
+ }
+ curr = None
+ return retStmt
+ case iterStatementExec: IterateStatementExec if !iterStatementExec.hasBeenMatched =>
+ if (label.contains(iterStatementExec.label)) {
+ iterStatementExec.hasBeenMatched = true
+ }
+ state = RepeatState.Condition
+ curr = Some(condition)
+ condition.reset()
+ return retStmt
+ case _ =>
+ }
+
+ if (!body.getTreeIterator.hasNext) {
+ state = RepeatState.Condition
+ curr = Some(condition)
+ condition.reset()
+ }
+ retStmt
+ }
+ }
+
+ override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator
+
+ override def reset(): Unit = {
+ state = RepeatState.Body
+ curr = Some(body)
+ body.reset()
+ condition.reset()
+ }
+}
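
The state machine above starts in the Body state, which mirrors REPEAT ... UNTIL semantics: the body always runs at least once, and the loop stops once the condition evaluates to true. A minimal standalone sketch of that control flow (illustrative names, not part of this patch):

```scala
// Spark-free model of the Body -> Condition alternation in RepeatStatementExec.
object RepeatSketch {
  def repeatUntil(body: () => Unit, condition: () => Boolean): Unit = {
    var done = false
    while (!done) {
      body()             // Body state: drain the body first
      done = condition() // Condition state: stop once it evaluates to true
    }
  }

  def main(args: Array[String]): Unit = {
    var i = 0
    repeatUntil(() => i += 1, () => i >= 3)
    println(i) // 3, the body executed three times, the condition checked after each pass
  }
}
```
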
+
+/**
+ * Executable node for LeaveStatement.
+ * @param label Label of the compound or loop to leave.
+ */
+class LeaveStatementExec(val label: String) extends LeafStatementExec {
+ /**
+ * The label specified in the LEAVE statement might not belong to the immediate surrounding
+ * compound, but to any enclosing compound.
+ * Iteration logic is recursive, i.e. when iterating through a compound, if another
+ * compound is encountered, next() is called to iterate its body. The same logic
+ * applies to any other compound further down the traversal tree.
+ * In such cases, when a LEAVE statement is encountered (as a leaf of the traversal tree),
+ * it is propagated upwards and matched against the labels of the surrounding compounds.
+ * Once a match is found, this flag is set to true to indicate that the search should stop.
+ */
+ var hasBeenMatched: Boolean = false
+ override def reset(): Unit = hasBeenMatched = false
+}
+
+/**
+ * Executable node for ITERATE statement.
+ * @param label Label of the loop to iterate.
+ */
+class IterateStatementExec(val label: String) extends LeafStatementExec {
+ /**
+ * The label specified in the ITERATE statement might not belong to the immediate surrounding
+ * compound, but to any enclosing compound.
+ * Iteration logic is recursive, i.e. when iterating through a compound, if another
+ * compound is encountered, next() is called to iterate its body. The same logic
+ * applies to any other compound further down the traversal tree.
+ * In such cases, when an ITERATE statement is encountered (as a leaf of the traversal tree),
+ * it is propagated upwards and matched against the labels of the surrounding compounds.
+ * Once a match is found, this flag is set to true to indicate that the search should stop.
+ */
+ var hasBeenMatched: Boolean = false
+ override def reset(): Unit = hasBeenMatched = false
+}
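
Both LEAVE and ITERATE rely on the same propagation idea described in the comments above: the leaf bubbles up through enclosing loops, and each loop checks its own label and marks the signal as handled when it matches. A small Spark-free sketch of that matching step (illustrative names, not part of this patch):

```scala
// Model of label matching for LEAVE/ITERATE propagation.
object LabelMatchSketch {
  final case class Signal(label: String, var hasBeenMatched: Boolean = false)

  /** Returns true once the signal is handled and should stop propagating. */
  def tryMatch(loopLabel: Option[String], signal: Signal): Boolean = {
    if (!signal.hasBeenMatched && loopLabel.contains(signal.label)) {
      signal.hasBeenMatched = true
    }
    signal.hasBeenMatched
  }

  def main(args: Array[String]): Unit = {
    val leave = Signal("outer")
    println(tryMatch(Some("inner"), leave)) // false, keeps propagating upwards
    println(tryMatch(Some("outer"), leave)) // true, matched here and the search stops
  }
}
```
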
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
index 08b4f97286280..917b4d6f45ee0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.scripting
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier
-import org.apache.spark.sql.catalyst.parser.{CompoundBody, CompoundPlanStatement, IfElseStatement, SingleStatement, WhileStatement}
+import org.apache.spark.sql.catalyst.parser.{CaseStatement, CompoundBody, CompoundPlanStatement, IfElseStatement, IterateStatement, LeaveStatement, RepeatStatement, SingleStatement, WhileStatement}
import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DropVariable, LogicalPlan}
import org.apache.spark.sql.catalyst.trees.Origin
@@ -71,9 +71,9 @@ case class SqlScriptingInterpreter() {
private def transformTreeIntoExecutable(
node: CompoundPlanStatement, session: SparkSession): CompoundStatementExec =
node match {
- case body: CompoundBody =>
+ case CompoundBody(collection, label) =>
// TODO [SPARK-48530]: Current logic doesn't support scoped variables and shadowing.
- val variables = body.collection.flatMap {
+ val variables = collection.flatMap {
case st: SingleStatement => getDeclareVarNameFromPlan(st.parsedPlan)
case _ => None
}
@@ -82,7 +82,9 @@ case class SqlScriptingInterpreter() {
.map(new SingleStatementExec(_, Origin(), isInternal = true))
.reverse
new CompoundBodyExec(
- body.collection.map(st => transformTreeIntoExecutable(st, session)) ++ dropVariables)
+ collection.map(st => transformTreeIntoExecutable(st, session)) ++ dropVariables,
+ label)
+
case IfElseStatement(conditions, conditionalBodies, elseBody) =>
val conditionsExec = conditions.map(condition =>
new SingleStatementExec(condition.parsedPlan, condition.origin, isInternal = false))
@@ -92,12 +94,38 @@ case class SqlScriptingInterpreter() {
transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec])
new IfElseStatementExec(
conditionsExec, conditionalBodiesExec, unconditionalBodiesExec, session)
- case WhileStatement(condition, body, _) =>
+
+ case CaseStatement(conditions, conditionalBodies, elseBody) =>
+ val conditionsExec = conditions.map(condition =>
+        // TODO: decide what isInternal should be for conditions of a simple CASE statement.
+ new SingleStatementExec(condition.parsedPlan, condition.origin, isInternal = false))
+ val conditionalBodiesExec = conditionalBodies.map(body =>
+ transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec])
+ val unconditionalBodiesExec = elseBody.map(body =>
+ transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec])
+ new CaseStatementExec(
+ conditionsExec, conditionalBodiesExec, unconditionalBodiesExec, session)
+
+ case WhileStatement(condition, body, label) =>
val conditionExec =
new SingleStatementExec(condition.parsedPlan, condition.origin, isInternal = false)
val bodyExec =
transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec]
- new WhileStatementExec(conditionExec, bodyExec, session)
+ new WhileStatementExec(conditionExec, bodyExec, label, session)
+
+ case RepeatStatement(condition, body, label) =>
+ val conditionExec =
+ new SingleStatementExec(condition.parsedPlan, condition.origin, isInternal = false)
+ val bodyExec =
+ transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec]
+ new RepeatStatementExec(conditionExec, bodyExec, label, session)
+
+ case leaveStatement: LeaveStatement =>
+ new LeaveStatementExec(leaveStatement.label)
+
+ case iterateStatement: IterateStatement =>
+ new IterateStatementExec(iterateStatement.label)
+
case sparkStatement: SingleStatement =>
new SingleStatementExec(
sparkStatement.parsedPlan,
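
The interpreter change above follows one recursive pattern: every parsed statement maps to its executable counterpart, and compound nodes recurse into their children before wrapping them. A toy sketch of that transform shape, using stand-in types rather than the Spark classes:

```scala
// Toy recursive parsed-tree -> executable-tree transform, mirroring the pattern
// used in transformTreeIntoExecutable (types here are illustrative stand-ins).
object TransformSketch {
  sealed trait Parsed
  case class Single(sql: String) extends Parsed
  case class Compound(children: Seq[Parsed], label: Option[String]) extends Parsed

  sealed trait Exec
  case class SingleExec(sql: String) extends Exec
  case class CompoundExec(children: Seq[Exec], label: Option[String]) extends Exec

  def transform(node: Parsed): Exec = node match {
    case Single(sql)               => SingleExec(sql)
    case Compound(children, label) => CompoundExec(children.map(transform), label)
  }

  def main(args: Array[String]): Unit = {
    val tree = Compound(
      Seq(Single("SELECT 1"), Compound(Seq(Single("SELECT 2")), Some("lbl"))),
      None)
    println(transform(tree))
  }
}
```
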
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala
index 63d937cb34820..7cf92db59067c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala
@@ -14,159 +14,12 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
package org.apache.spark.sql.streaming
-import java.util.UUID
-import java.util.concurrent.TimeoutException
-
-import org.apache.spark.annotation.Evolving
-import org.apache.spark.sql.SparkSession
-
-/**
- * A handle to a query that is executing continuously in the background as new data arrives.
- * All these methods are thread-safe.
- * @since 2.0.0
- */
-@Evolving
-trait StreamingQuery {
-
- /**
- * Returns the user-specified name of the query, or null if not specified.
- * This name can be specified in the `org.apache.spark.sql.streaming.DataStreamWriter`
- * as `dataframe.writeStream.queryName("query").start()`.
- * This name, if set, must be unique across all active queries.
- *
- * @since 2.0.0
- */
- def name: String
-
- /**
- * Returns the unique id of this query that persists across restarts from checkpoint data.
- * That is, this id is generated when a query is started for the first time, and
- * will be the same every time it is restarted from checkpoint data. Also see [[runId]].
- *
- * @since 2.1.0
- */
- def id: UUID
-
- /**
- * Returns the unique id of this run of the query. That is, every start/restart of a query will
- * generate a unique runId. Therefore, every time a query is restarted from
- * checkpoint, it will have the same [[id]] but different [[runId]]s.
- */
- def runId: UUID
-
- /**
- * Returns the `SparkSession` associated with `this`.
- *
- * @since 2.0.0
- */
- def sparkSession: SparkSession
-
- /**
- * Returns `true` if this query is actively running.
- *
- * @since 2.0.0
- */
- def isActive: Boolean
-
- /**
- * Returns the [[StreamingQueryException]] if the query was terminated by an exception.
- * @since 2.0.0
- */
- def exception: Option[StreamingQueryException]
-
- /**
- * Returns the current status of the query.
- *
- * @since 2.0.2
- */
- def status: StreamingQueryStatus
-
- /**
- * Returns an array of the most recent [[StreamingQueryProgress]] updates for this query.
- * The number of progress updates retained for each stream is configured by Spark session
- * configuration `spark.sql.streaming.numRecentProgressUpdates`.
- *
- * @since 2.1.0
- */
- def recentProgress: Array[StreamingQueryProgress]
-
- /**
- * Returns the most recent [[StreamingQueryProgress]] update of this streaming query.
- *
- * @since 2.1.0
- */
- def lastProgress: StreamingQueryProgress
-
- /**
- * Waits for the termination of `this` query, either by `query.stop()` or by an exception.
- * If the query has terminated with an exception, then the exception will be thrown.
- *
- * If the query has terminated, then all subsequent calls to this method will either return
- * immediately (if the query was terminated by `stop()`), or throw the exception
- * immediately (if the query has terminated with exception).
- *
- * @throws StreamingQueryException if the query has terminated with an exception.
- *
- * @since 2.0.0
- */
- @throws[StreamingQueryException]
- def awaitTermination(): Unit
-
- /**
- * Waits for the termination of `this` query, either by `query.stop()` or by an exception.
- * If the query has terminated with an exception, then the exception will be thrown.
- * Otherwise, it returns whether the query has terminated or not within the `timeoutMs`
- * milliseconds.
- *
- * If the query has terminated, then all subsequent calls to this method will either return
- * `true` immediately (if the query was terminated by `stop()`), or throw the exception
- * immediately (if the query has terminated with exception).
- *
- * @throws StreamingQueryException if the query has terminated with an exception
- *
- * @since 2.0.0
- */
- @throws[StreamingQueryException]
- def awaitTermination(timeoutMs: Long): Boolean
-
- /**
- * Blocks until all available data in the source has been processed and committed to the sink.
- * This method is intended for testing. Note that in the case of continually arriving data, this
- * method may block forever. Additionally, this method is only guaranteed to block until data that
- * has been synchronously appended data to a `org.apache.spark.sql.execution.streaming.Source`
- * prior to invocation. (i.e. `getOffset` must immediately reflect the addition).
- * @since 2.0.0
- */
- def processAllAvailable(): Unit
-
- /**
- * Stops the execution of this query if it is running. This waits until the termination of the
- * query execution threads or until a timeout is hit.
- *
- * By default stop will block indefinitely. You can configure a timeout by the configuration
- * `spark.sql.streaming.stopTimeout`. A timeout of 0 (or negative) milliseconds will block
- * indefinitely. If a `TimeoutException` is thrown, users can retry stopping the stream. If the
- * issue persists, it is advisable to kill the Spark application.
- *
- * @since 2.0.0
- */
- @throws[TimeoutException]
- def stop(): Unit
-
- /**
- * Prints the physical plan to the console for debugging purposes.
- * @since 2.0.0
- */
- def explain(): Unit
+import org.apache.spark.sql.{api, SparkSession}
- /**
- * Prints the physical plan to the console for debugging purposes.
- *
- * @param extended whether to do extended explain or not
- * @since 2.0.0
- */
- def explain(extended: Boolean): Unit
+/** @inheritdoc */
+trait StreamingQuery extends api.StreamingQuery {
+ /** @inheritdoc */
+ override def sparkSession: SparkSession
}
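
The trait now only forwards to api.StreamingQuery, so the user-facing handle surface (id, runId, isActive, awaitTermination, stop, progress accessors) is unchanged. A minimal usage sketch of that surface, with the rate source chosen purely for illustration:

```scala
// Minimal sketch of driving a StreamingQuery handle; assumes a local Spark build.
import org.apache.spark.sql.SparkSession

object StreamingQuerySketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
    val query = spark.readStream.format("rate").load()
      .writeStream.format("console").start()

    println(query.id)             // stable across restarts from the same checkpoint
    println(query.runId)          // unique for every start/restart
    println(query.isActive)
    query.awaitTermination(5000L) // wait up to 5 seconds for termination
    query.stop()
    spark.stop()
  }
}
```
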
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala
deleted file mode 100644
index c1ceed048ae2c..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala
+++ /dev/null
@@ -1,276 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.streaming
-
-import java.util.UUID
-
-import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper}
-import com.fasterxml.jackson.module.scala.{ClassTagExtensions, DefaultScalaModule}
-import org.json4s.{JObject, JString}
-import org.json4s.JsonAST.JValue
-import org.json4s.JsonDSL.{jobject2assoc, pair2Assoc}
-import org.json4s.jackson.JsonMethods.{compact, render}
-
-import org.apache.spark.annotation.Evolving
-import org.apache.spark.scheduler.SparkListenerEvent
-
-/**
- * Interface for listening to events related to [[StreamingQuery StreamingQueries]].
- * @note The methods are not thread-safe as they may be called from different threads.
- *
- * @since 2.0.0
- */
-@Evolving
-abstract class StreamingQueryListener extends Serializable {
-
- import StreamingQueryListener._
-
- /**
- * Called when a query is started.
- * @note This is called synchronously with
- * [[org.apache.spark.sql.streaming.DataStreamWriter `DataStreamWriter.start()`]],
- * that is, `onQueryStart` will be called on all listeners before
- * `DataStreamWriter.start()` returns the corresponding [[StreamingQuery]]. Please
- * don't block this method as it will block your query.
- * @since 2.0.0
- */
- def onQueryStarted(event: QueryStartedEvent): Unit
-
- /**
- * Called when there is some status update (ingestion rate updated, etc.)
- *
- * @note This method is asynchronous. The status in [[StreamingQuery]] will always be
- * latest no matter when this method is called. Therefore, the status of [[StreamingQuery]]
- * may be changed before/when you process the event. E.g., you may find [[StreamingQuery]]
- * is terminated when you are processing `QueryProgressEvent`.
- * @since 2.0.0
- */
- def onQueryProgress(event: QueryProgressEvent): Unit
-
- /**
- * Called when the query is idle and waiting for new data to process.
- * @since 3.5.0
- */
- def onQueryIdle(event: QueryIdleEvent): Unit = {}
-
- /**
- * Called when a query is stopped, with or without error.
- * @since 2.0.0
- */
- def onQueryTerminated(event: QueryTerminatedEvent): Unit
-}
-
-/**
- * Py4J allows a pure interface so this proxy is required.
- */
-private[spark] trait PythonStreamingQueryListener {
- import StreamingQueryListener._
-
- def onQueryStarted(event: QueryStartedEvent): Unit
-
- def onQueryProgress(event: QueryProgressEvent): Unit
-
- def onQueryIdle(event: QueryIdleEvent): Unit
-
- def onQueryTerminated(event: QueryTerminatedEvent): Unit
-}
-
-private[spark] class PythonStreamingQueryListenerWrapper(
- listener: PythonStreamingQueryListener) extends StreamingQueryListener {
- import StreamingQueryListener._
-
- def onQueryStarted(event: QueryStartedEvent): Unit = listener.onQueryStarted(event)
-
- def onQueryProgress(event: QueryProgressEvent): Unit = listener.onQueryProgress(event)
-
- override def onQueryIdle(event: QueryIdleEvent): Unit = listener.onQueryIdle(event)
-
- def onQueryTerminated(event: QueryTerminatedEvent): Unit = listener.onQueryTerminated(event)
-}
-
-/**
- * Companion object of [[StreamingQueryListener]] that defines the listener events.
- * @since 2.0.0
- */
-@Evolving
-object StreamingQueryListener extends Serializable {
-
- /**
- * Base type of [[StreamingQueryListener]] events
- * @since 2.0.0
- */
- @Evolving
- trait Event extends SparkListenerEvent
-
- /**
- * Event representing the start of a query
- * @param id A unique query id that persists across restarts. See `StreamingQuery.id()`.
- * @param runId A query id that is unique for every start/restart. See `StreamingQuery.runId()`.
- * @param name User-specified name of the query, null if not specified.
- * @param timestamp The timestamp to start a query.
- * @since 2.1.0
- */
- @Evolving
- class QueryStartedEvent private[sql](
- val id: UUID,
- val runId: UUID,
- val name: String,
- val timestamp: String) extends Event with Serializable {
-
- def json: String = compact(render(jsonValue))
-
- private def jsonValue: JValue = {
- ("id" -> JString(id.toString)) ~
- ("runId" -> JString(runId.toString)) ~
- ("name" -> JString(name)) ~
- ("timestamp" -> JString(timestamp))
- }
- }
-
- private[spark] object QueryStartedEvent {
- private val mapper = {
- val ret = new ObjectMapper() with ClassTagExtensions
- ret.registerModule(DefaultScalaModule)
- ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
- ret
- }
-
- private[spark] def jsonString(event: QueryStartedEvent): String =
- mapper.writeValueAsString(event)
-
- private[spark] def fromJson(json: String): QueryStartedEvent =
- mapper.readValue[QueryStartedEvent](json)
- }
-
- /**
- * Event representing any progress updates in a query.
- * @param progress The query progress updates.
- * @since 2.1.0
- */
- @Evolving
- class QueryProgressEvent private[sql](val progress: StreamingQueryProgress) extends Event
- with Serializable {
-
- def json: String = compact(render(jsonValue))
-
- private def jsonValue: JValue = JObject("progress" -> progress.jsonValue)
- }
-
- private[spark] object QueryProgressEvent {
- private val mapper = {
- val ret = new ObjectMapper() with ClassTagExtensions
- ret.registerModule(DefaultScalaModule)
- ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
- ret
- }
-
- private[spark] def jsonString(event: QueryProgressEvent): String =
- mapper.writeValueAsString(event)
-
- private[spark] def fromJson(json: String): QueryProgressEvent =
- mapper.readValue[QueryProgressEvent](json)
- }
-
- /**
- * Event representing that query is idle and waiting for new data to process.
- *
- * @param id A unique query id that persists across restarts. See `StreamingQuery.id()`.
- * @param runId A query id that is unique for every start/restart. See `StreamingQuery.runId()`.
- * @param timestamp The timestamp when the latest no-batch trigger happened.
- * @since 3.5.0
- */
- @Evolving
- class QueryIdleEvent private[sql](
- val id: UUID,
- val runId: UUID,
- val timestamp: String) extends Event with Serializable {
-
- def json: String = compact(render(jsonValue))
-
- private def jsonValue: JValue = {
- ("id" -> JString(id.toString)) ~
- ("runId" -> JString(runId.toString)) ~
- ("timestamp" -> JString(timestamp))
- }
- }
-
- private[spark] object QueryIdleEvent {
- private val mapper = {
- val ret = new ObjectMapper() with ClassTagExtensions
- ret.registerModule(DefaultScalaModule)
- ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
- ret
- }
-
- private[spark] def jsonString(event: QueryTerminatedEvent): String =
- mapper.writeValueAsString(event)
-
- private[spark] def fromJson(json: String): QueryTerminatedEvent =
- mapper.readValue[QueryTerminatedEvent](json)
- }
-
- /**
- * Event representing that termination of a query.
- *
- * @param id A unique query id that persists across restarts. See `StreamingQuery.id()`.
- * @param runId A query id that is unique for every start/restart. See `StreamingQuery.runId()`.
- * @param exception The exception message of the query if the query was terminated
- * with an exception. Otherwise, it will be `None`.
- * @param errorClassOnException The error class from the exception if the query was terminated
- * with an exception which is a part of error class framework.
- * If the query was terminated without an exception, or the
- * exception is not a part of error class framework, it will be
- * `None`.
- * @since 2.1.0
- */
- @Evolving
- class QueryTerminatedEvent private[sql](
- val id: UUID,
- val runId: UUID,
- val exception: Option[String],
- val errorClassOnException: Option[String]) extends Event with Serializable {
- // compatibility with versions in prior to 3.5.0
- def this(id: UUID, runId: UUID, exception: Option[String]) = {
- this(id, runId, exception, None)
- }
-
- def json: String = compact(render(jsonValue))
-
- private def jsonValue: JValue = {
- ("id" -> JString(id.toString)) ~
- ("runId" -> JString(runId.toString)) ~
- ("exception" -> JString(exception.orNull)) ~
- ("errorClassOnException" -> JString(errorClassOnException.orNull))
- }
- }
-
- private[spark] object QueryTerminatedEvent {
- private val mapper = {
- val ret = new ObjectMapper() with ClassTagExtensions
- ret.registerModule(DefaultScalaModule)
- ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
- ret
- }
-
- private[spark] def jsonString(event: QueryTerminatedEvent): String =
- mapper.writeValueAsString(event)
-
- private[spark] def fromJson(json: String): QueryTerminatedEvent =
- mapper.readValue[QueryTerminatedEvent](json)
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
index 55d2e639a56b1..3ab6d02f6b515 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
@@ -364,7 +364,7 @@ class StreamingQueryManager private[sql] (
.orElse(activeQueries.get(query.id)) // shouldn't be needed but paranoia ...
val shouldStopActiveRun =
- sparkSession.conf.get(SQLConf.STREAMING_STOP_ACTIVE_RUN_ON_RESTART)
+ sparkSession.sessionState.conf.getConf(SQLConf.STREAMING_STOP_ACTIVE_RUN_ON_RESTART)
if (activeOption.isDefined) {
if (shouldStopActiveRun) {
val oldQuery = activeOption.get
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala
deleted file mode 100644
index fe187917ec021..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala
+++ /dev/null
@@ -1,68 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.streaming
-
-import org.json4s._
-import org.json4s.JsonAST.JValue
-import org.json4s.JsonDSL._
-import org.json4s.jackson.JsonMethods._
-
-import org.apache.spark.annotation.Evolving
-
-/**
- * Reports information about the instantaneous status of a streaming query.
- *
- * @param message A human readable description of what the stream is currently doing.
- * @param isDataAvailable True when there is new data to be processed. Doesn't apply
- * to ContinuousExecution where it is always false.
- * @param isTriggerActive True when the trigger is actively firing, false when waiting for the
- * next trigger time. Doesn't apply to ContinuousExecution where it is
- * always false.
- *
- * @since 2.1.0
- */
-@Evolving
-class StreamingQueryStatus protected[sql](
- val message: String,
- val isDataAvailable: Boolean,
- val isTriggerActive: Boolean) extends Serializable {
-
- /** The compact JSON representation of this status. */
- def json: String = compact(render(jsonValue))
-
- /** The pretty (i.e. indented) JSON representation of this status. */
- def prettyJson: String = pretty(render(jsonValue))
-
- override def toString: String = prettyJson
-
- private[sql] def copy(
- message: String = this.message,
- isDataAvailable: Boolean = this.isDataAvailable,
- isTriggerActive: Boolean = this.isTriggerActive): StreamingQueryStatus = {
- new StreamingQueryStatus(
- message = message,
- isDataAvailable = isDataAvailable,
- isTriggerActive = isTriggerActive)
- }
-
- private[sql] def jsonValue: JValue = {
- ("message" -> JString(message)) ~
- ("isDataAvailable" -> JBool(isDataAvailable)) ~
- ("isTriggerActive" -> JBool(isTriggerActive))
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
deleted file mode 100644
index 05323d9d03811..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
+++ /dev/null
@@ -1,301 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.streaming
-
-import java.{util => ju}
-import java.lang.{Long => JLong}
-import java.util.UUID
-
-import scala.jdk.CollectionConverters._
-import scala.util.control.NonFatal
-
-import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper}
-import com.fasterxml.jackson.databind.annotation.JsonDeserialize
-import com.fasterxml.jackson.module.scala.{ClassTagExtensions, DefaultScalaModule}
-import org.json4s._
-import org.json4s.JsonAST.JValue
-import org.json4s.JsonDSL._
-import org.json4s.jackson.JsonMethods._
-
-import org.apache.spark.annotation.Evolving
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
-import org.apache.spark.sql.streaming.SafeJsonSerializer.{safeDoubleToJValue, safeMapToJValue}
-import org.apache.spark.sql.streaming.SinkProgress.DEFAULT_NUM_OUTPUT_ROWS
-
-/**
- * Information about updates made to stateful operators in a [[StreamingQuery]] during a trigger.
- */
-@Evolving
-class StateOperatorProgress private[spark](
- val operatorName: String,
- val numRowsTotal: Long,
- val numRowsUpdated: Long,
- val allUpdatesTimeMs: Long,
- val numRowsRemoved: Long,
- val allRemovalsTimeMs: Long,
- val commitTimeMs: Long,
- val memoryUsedBytes: Long,
- val numRowsDroppedByWatermark: Long,
- val numShufflePartitions: Long,
- val numStateStoreInstances: Long,
- val customMetrics: ju.Map[String, JLong] = new ju.HashMap()
- ) extends Serializable {
-
- /** The compact JSON representation of this progress. */
- def json: String = compact(render(jsonValue))
-
- /** The pretty (i.e. indented) JSON representation of this progress. */
- def prettyJson: String = pretty(render(jsonValue))
-
- private[sql] def copy(
- newNumRowsUpdated: Long,
- newNumRowsDroppedByWatermark: Long): StateOperatorProgress =
- new StateOperatorProgress(
- operatorName = operatorName, numRowsTotal = numRowsTotal, numRowsUpdated = newNumRowsUpdated,
- allUpdatesTimeMs = allUpdatesTimeMs, numRowsRemoved = numRowsRemoved,
- allRemovalsTimeMs = allRemovalsTimeMs, commitTimeMs = commitTimeMs,
- memoryUsedBytes = memoryUsedBytes, numRowsDroppedByWatermark = newNumRowsDroppedByWatermark,
- numShufflePartitions = numShufflePartitions, numStateStoreInstances = numStateStoreInstances,
- customMetrics = customMetrics)
-
- private[sql] def jsonValue: JValue = {
- ("operatorName" -> JString(operatorName)) ~
- ("numRowsTotal" -> JInt(numRowsTotal)) ~
- ("numRowsUpdated" -> JInt(numRowsUpdated)) ~
- ("allUpdatesTimeMs" -> JInt(allUpdatesTimeMs)) ~
- ("numRowsRemoved" -> JInt(numRowsRemoved)) ~
- ("allRemovalsTimeMs" -> JInt(allRemovalsTimeMs)) ~
- ("commitTimeMs" -> JInt(commitTimeMs)) ~
- ("memoryUsedBytes" -> JInt(memoryUsedBytes)) ~
- ("numRowsDroppedByWatermark" -> JInt(numRowsDroppedByWatermark)) ~
- ("numShufflePartitions" -> JInt(numShufflePartitions)) ~
- ("numStateStoreInstances" -> JInt(numStateStoreInstances)) ~
- ("customMetrics" -> {
- if (!customMetrics.isEmpty) {
- val keys = customMetrics.keySet.asScala.toSeq.sorted
- keys.map { k => k -> JInt(customMetrics.get(k).toLong) : JObject }.reduce(_ ~ _)
- } else {
- JNothing
- }
- })
- }
-
- override def toString: String = prettyJson
-}
-
-/**
- * Information about progress made in the execution of a [[StreamingQuery]] during
- * a trigger. Each event relates to processing done for a single trigger of the streaming
- * query. Events are emitted even when no new data is available to be processed.
- *
- * @param id A unique query id that persists across restarts. See `StreamingQuery.id()`.
- * @param runId A query id that is unique for every start/restart. See `StreamingQuery.runId()`.
- * @param name User-specified name of the query, null if not specified.
- * @param timestamp Beginning time of the trigger in ISO8601 format, i.e. UTC timestamps.
- * @param batchId A unique id for the current batch of data being processed. Note that in the
- * case of retries after a failure a given batchId my be executed more than once.
- * Similarly, when there is no data to be processed, the batchId will not be
- * incremented.
- * @param batchDuration The process duration of each batch.
- * @param durationMs The amount of time taken to perform various operations in milliseconds.
- * @param eventTime Statistics of event time seen in this batch. It may contain the following keys:
- * {{{
- * "max" -> "2016-12-05T20:54:20.827Z" // maximum event time seen in this trigger
- * "min" -> "2016-12-05T20:54:20.827Z" // minimum event time seen in this trigger
- * "avg" -> "2016-12-05T20:54:20.827Z" // average event time seen in this trigger
- * "watermark" -> "2016-12-05T20:54:20.827Z" // watermark used in this trigger
- * }}}
- * All timestamps are in ISO8601 format, i.e. UTC timestamps.
- * @param stateOperators Information about operators in the query that store state.
- * @param sources detailed statistics on data being read from each of the streaming sources.
- * @since 2.1.0
- */
-@Evolving
-class StreamingQueryProgress private[spark](
- val id: UUID,
- val runId: UUID,
- val name: String,
- val timestamp: String,
- val batchId: Long,
- val batchDuration: Long,
- val durationMs: ju.Map[String, JLong],
- val eventTime: ju.Map[String, String],
- val stateOperators: Array[StateOperatorProgress],
- val sources: Array[SourceProgress],
- val sink: SinkProgress,
- @JsonDeserialize(contentAs = classOf[GenericRowWithSchema])
- val observedMetrics: ju.Map[String, Row]) extends Serializable {
-
- /** The aggregate (across all sources) number of records processed in a trigger. */
- def numInputRows: Long = sources.map(_.numInputRows).sum
-
- /** The aggregate (across all sources) rate of data arriving. */
- def inputRowsPerSecond: Double = sources.map(_.inputRowsPerSecond).sum
-
- /** The aggregate (across all sources) rate at which Spark is processing data. */
- def processedRowsPerSecond: Double = sources.map(_.processedRowsPerSecond).sum
-
- /** The compact JSON representation of this progress. */
- def json: String = compact(render(jsonValue))
-
- /** The pretty (i.e. indented) JSON representation of this progress. */
- def prettyJson: String = pretty(render(jsonValue))
-
- override def toString: String = prettyJson
-
- private[sql] def jsonValue: JValue = {
- ("id" -> JString(id.toString)) ~
- ("runId" -> JString(runId.toString)) ~
- ("name" -> JString(name)) ~
- ("timestamp" -> JString(timestamp)) ~
- ("batchId" -> JInt(batchId)) ~
- ("batchDuration" -> JInt(batchDuration)) ~
- ("numInputRows" -> JInt(numInputRows)) ~
- ("inputRowsPerSecond" -> safeDoubleToJValue(inputRowsPerSecond)) ~
- ("processedRowsPerSecond" -> safeDoubleToJValue(processedRowsPerSecond)) ~
- ("durationMs" -> safeMapToJValue[JLong](durationMs, v => JInt(v.toLong))) ~
- ("eventTime" -> safeMapToJValue[String](eventTime, s => JString(s))) ~
- ("stateOperators" -> JArray(stateOperators.map(_.jsonValue).toList)) ~
- ("sources" -> JArray(sources.map(_.jsonValue).toList)) ~
- ("sink" -> sink.jsonValue) ~
- ("observedMetrics" -> safeMapToJValue[Row](observedMetrics, row => row.jsonValue))
- }
-}
-
-private[spark] object StreamingQueryProgress {
- private[this] val mapper = {
- val ret = new ObjectMapper() with ClassTagExtensions
- ret.registerModule(DefaultScalaModule)
- ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
- ret
- }
-
- private[spark] def jsonString(progress: StreamingQueryProgress): String =
- mapper.writeValueAsString(progress)
-
- private[spark] def fromJson(json: String): StreamingQueryProgress =
- mapper.readValue[StreamingQueryProgress](json)
-}
-
-/**
- * Information about progress made for a source in the execution of a [[StreamingQuery]]
- * during a trigger. See [[StreamingQueryProgress]] for more information.
- *
- * @param description Description of the source.
- * @param startOffset The starting offset for data being read.
- * @param endOffset The ending offset for data being read.
- * @param latestOffset The latest offset from this source.
- * @param numInputRows The number of records read from this source.
- * @param inputRowsPerSecond The rate at which data is arriving from this source.
- * @param processedRowsPerSecond The rate at which data from this source is being processed by
- * Spark.
- * @since 2.1.0
- */
-@Evolving
-class SourceProgress protected[spark](
- val description: String,
- val startOffset: String,
- val endOffset: String,
- val latestOffset: String,
- val numInputRows: Long,
- val inputRowsPerSecond: Double,
- val processedRowsPerSecond: Double,
- val metrics: ju.Map[String, String] = Map[String, String]().asJava) extends Serializable {
-
- /** The compact JSON representation of this progress. */
- def json: String = compact(render(jsonValue))
-
- /** The pretty (i.e. indented) JSON representation of this progress. */
- def prettyJson: String = pretty(render(jsonValue))
-
- override def toString: String = prettyJson
-
- private[sql] def jsonValue: JValue = {
- ("description" -> JString(description)) ~
- ("startOffset" -> tryParse(startOffset)) ~
- ("endOffset" -> tryParse(endOffset)) ~
- ("latestOffset" -> tryParse(latestOffset)) ~
- ("numInputRows" -> JInt(numInputRows)) ~
- ("inputRowsPerSecond" -> safeDoubleToJValue(inputRowsPerSecond)) ~
- ("processedRowsPerSecond" -> safeDoubleToJValue(processedRowsPerSecond)) ~
- ("metrics" -> safeMapToJValue[String](metrics, s => JString(s)))
- }
-
- private def tryParse(json: String) = try {
- parse(json)
- } catch {
- case NonFatal(e) => JString(json)
- }
-}
-
-/**
- * Information about progress made for a sink in the execution of a [[StreamingQuery]]
- * during a trigger. See [[StreamingQueryProgress]] for more information.
- *
- * @param description Description of the source corresponding to this status.
- * @param numOutputRows Number of rows written to the sink or -1 for Continuous Mode (temporarily)
- * or Sink V1 (until decommissioned).
- * @since 2.1.0
- */
-@Evolving
-class SinkProgress protected[spark](
- val description: String,
- val numOutputRows: Long,
- val metrics: ju.Map[String, String] = Map[String, String]().asJava) extends Serializable {
-
- /** SinkProgress without custom metrics. */
- protected[sql] def this(description: String) = {
- this(description, DEFAULT_NUM_OUTPUT_ROWS)
- }
-
- /** The compact JSON representation of this progress. */
- def json: String = compact(render(jsonValue))
-
- /** The pretty (i.e. indented) JSON representation of this progress. */
- def prettyJson: String = pretty(render(jsonValue))
-
- override def toString: String = prettyJson
-
- private[sql] def jsonValue: JValue = {
- ("description" -> JString(description)) ~
- ("numOutputRows" -> JInt(numOutputRows)) ~
- ("metrics" -> safeMapToJValue[String](metrics, s => JString(s)))
- }
-}
-
-private[sql] object SinkProgress {
- val DEFAULT_NUM_OUTPUT_ROWS: Long = -1L
-
- def apply(description: String, numOutputRows: Option[Long],
- metrics: ju.Map[String, String] = Map[String, String]().asJava): SinkProgress =
- new SinkProgress(description, numOutputRows.getOrElse(DEFAULT_NUM_OUTPUT_ROWS), metrics)
-}
-
-private object SafeJsonSerializer {
- def safeDoubleToJValue(value: Double): JValue = {
- if (value.isNaN || value.isInfinity) JNothing else JDouble(value)
- }
-
- /** Convert map to JValue while handling empty maps. Also, this sorts the keys. */
- def safeMapToJValue[T](map: ju.Map[String, T], valueToJValue: T => JValue): JValue = {
- if (map == null || map.isEmpty) return JNothing
- val keys = map.asScala.keySet.toSeq.sorted
- keys.map { k => k -> valueToJValue(map.get(k)) : JObject }.reduce(_ ~ _)
- }
-}
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 316e5e9676723..5ad1380e1fb82 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -265,6 +265,7 @@
| org.apache.spark.sql.catalyst.expressions.RaiseErrorExpressionBuilder | raise_error | SELECT raise_error('custom error message') | struct |
| org.apache.spark.sql.catalyst.expressions.Rand | rand | SELECT rand() | struct |
| org.apache.spark.sql.catalyst.expressions.Rand | random | SELECT random() | struct |
+| org.apache.spark.sql.catalyst.expressions.RandStr | randstr | SELECT randstr(3, 0) AS result | struct |
| org.apache.spark.sql.catalyst.expressions.Randn | randn | SELECT randn() | struct |
| org.apache.spark.sql.catalyst.expressions.Rank | rank | SELECT a, b, rank(b) OVER (PARTITION BY a ORDER BY b) FROM VALUES ('A1', 2), ('A1', 1), ('A2', 3), ('A1', 1) tab(a, b) | struct |
| org.apache.spark.sql.catalyst.expressions.RegExpCount | regexp_count | SELECT regexp_count('Steven Jones and Stephen Smith are the best players', 'Ste(v|ph)en') | struct |
@@ -367,6 +368,7 @@
| org.apache.spark.sql.catalyst.expressions.UnaryMinus | negative | SELECT negative(1) | struct |
| org.apache.spark.sql.catalyst.expressions.UnaryPositive | positive | SELECT positive(1) | struct<(+ 1):int> |
| org.apache.spark.sql.catalyst.expressions.Unhex | unhex | SELECT decode(unhex('537061726B2053514C'), 'UTF-8') | struct |
+| org.apache.spark.sql.catalyst.expressions.Uniform | uniform | SELECT uniform(10, 20, 0) > 0 AS result | struct |
| org.apache.spark.sql.catalyst.expressions.UnixDate | unix_date | SELECT unix_date(DATE("1970-01-02")) | struct |
| org.apache.spark.sql.catalyst.expressions.UnixMicros | unix_micros | SELECT unix_micros(TIMESTAMP('1970-01-01 00:00:01Z')) | struct |
| org.apache.spark.sql.catalyst.expressions.UnixMillis | unix_millis | SELECT unix_millis(TIMESTAMP('1970-01-01 00:00:01Z')) | struct |
@@ -451,6 +453,7 @@
| org.apache.spark.sql.catalyst.expressions.variant.ParseJsonExpressionBuilder | parse_json | SELECT parse_json('{"a":1,"b":0.8}') | struct |
| org.apache.spark.sql.catalyst.expressions.variant.SchemaOfVariant | schema_of_variant | SELECT schema_of_variant(parse_json('null')) | struct |
| org.apache.spark.sql.catalyst.expressions.variant.SchemaOfVariantAgg | schema_of_variant_agg | SELECT schema_of_variant_agg(parse_json(j)) FROM VALUES ('1'), ('2'), ('3') AS tab(j) | struct |
+| org.apache.spark.sql.catalyst.expressions.variant.ToVariantObject | to_variant_object | SELECT to_variant_object(named_struct('a', 1, 'b', 2)) | struct |
| org.apache.spark.sql.catalyst.expressions.variant.TryParseJsonExpressionBuilder | try_parse_json | SELECT try_parse_json('{"a":1,"b":0.8}') | struct |
| org.apache.spark.sql.catalyst.expressions.variant.TryVariantGetExpressionBuilder | try_variant_get | SELECT try_variant_get(parse_json('{"a": 1}'), '$.a', 'int') | struct |
| org.apache.spark.sql.catalyst.expressions.variant.VariantGetExpressionBuilder | variant_get | SELECT variant_get(parse_json('{"a": 1}'), '$.a', 'int') | struct |
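
The rows added to the table above register randstr, uniform, and to_variant_object. A quick sanity-check sketch that runs the same example queries through spark.sql, assuming a SparkSession built from a branch that includes these expressions:

```scala
// Exercise the newly registered expressions via their documented example queries.
import org.apache.spark.sql.SparkSession

object NewExpressionsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("expr-sketch").getOrCreate()
    spark.sql("SELECT randstr(3, 0) AS result").show()                 // seeded random string
    spark.sql("SELECT uniform(10, 20, 0) > 0 AS result").show()        // seeded value between the bounds
    spark.sql("SELECT to_variant_object(named_struct('a', 1, 'b', 2))").show()
    spark.stop()
  }
}
```
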
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/array.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/array.sql.out
index 57108c4582f45..53595d1b8a3eb 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/array.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/array.sql.out
@@ -194,25 +194,8 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
-- !query
select sort_array(array('b', 'd'), cast(NULL as boolean))
-- !query analysis
-org.apache.spark.sql.catalyst.ExtendedAnalysisException
-{
- "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
- "sqlState" : "42K09",
- "messageParameters" : {
- "inputSql" : "\"CAST(NULL AS BOOLEAN)\"",
- "inputType" : "\"BOOLEAN\"",
- "paramIndex" : "second",
- "requiredType" : "\"BOOLEAN\"",
- "sqlExpr" : "\"sort_array(array(b, d), CAST(NULL AS BOOLEAN))\""
- },
- "queryContext" : [ {
- "objectType" : "",
- "objectName" : "",
- "startIndex" : 8,
- "stopIndex" : 57,
- "fragment" : "sort_array(array('b', 'd'), cast(NULL as boolean))"
- } ]
-}
+Project [sort_array(array(b, d), cast(null as boolean)) AS sort_array(array(b, d), CAST(NULL AS BOOLEAN))#x]
++- OneRowRelation
-- !query
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/date.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/date.sql.out
index fd927b99c6456..0e4d2d4e99e26 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/date.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/date.sql.out
@@ -736,7 +736,7 @@ Project [to_date(26/October/2015, Some(dd/MMMMM/yyyy), Some(America/Los_Angeles)
-- !query
select from_json('{"d":"26/October/2015"}', 'd Date', map('dateFormat', 'dd/MMMMM/yyyy'))
-- !query analysis
-Project [from_json(StructField(d,DateType,true), (dateFormat,dd/MMMMM/yyyy), {"d":"26/October/2015"}, Some(America/Los_Angeles)) AS from_json({"d":"26/October/2015"})#x]
+Project [from_json(StructField(d,DateType,true), (dateFormat,dd/MMMMM/yyyy), {"d":"26/October/2015"}, Some(America/Los_Angeles), false) AS from_json({"d":"26/October/2015"})#x]
+- OneRowRelation
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/interval.sql.out
index 12756576ded9b..b0d128c4cab69 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/interval.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/interval.sql.out
@@ -981,9 +981,13 @@ select interval '20 15:40:32.99899999' day to hour
-- !query analysis
org.apache.spark.sql.catalyst.parser.ParseException
{
- "errorClass" : "_LEGACY_ERROR_TEMP_0063",
+ "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE",
+ "sqlState" : "22006",
"messageParameters" : {
- "msg" : "Interval string does not match day-time format of `[+|-]d h`, `INTERVAL [+|-]'[+|-]d h' DAY TO HOUR` when cast to interval day to hour: 20 15:40:32.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0."
+ "input" : "20 15:40:32.99899999",
+ "intervalStr" : "day-time",
+ "supportedFormat" : "`[+|-]d h`, `INTERVAL [+|-]'[+|-]d h' DAY TO HOUR`",
+ "typeName" : "interval day to hour"
},
"queryContext" : [ {
"objectType" : "",
@@ -1000,9 +1004,13 @@ select interval '20 15:40:32.99899999' day to minute
-- !query analysis
org.apache.spark.sql.catalyst.parser.ParseException
{
- "errorClass" : "_LEGACY_ERROR_TEMP_0063",
+ "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE",
+ "sqlState" : "22006",
"messageParameters" : {
- "msg" : "Interval string does not match day-time format of `[+|-]d h:m`, `INTERVAL [+|-]'[+|-]d h:m' DAY TO MINUTE` when cast to interval day to minute: 20 15:40:32.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0."
+ "input" : "20 15:40:32.99899999",
+ "intervalStr" : "day-time",
+ "supportedFormat" : "`[+|-]d h:m`, `INTERVAL [+|-]'[+|-]d h:m' DAY TO MINUTE`",
+ "typeName" : "interval day to minute"
},
"queryContext" : [ {
"objectType" : "",
@@ -1019,9 +1027,13 @@ select interval '15:40:32.99899999' hour to minute
-- !query analysis
org.apache.spark.sql.catalyst.parser.ParseException
{
- "errorClass" : "_LEGACY_ERROR_TEMP_0063",
+ "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE",
+ "sqlState" : "22006",
"messageParameters" : {
- "msg" : "Interval string does not match day-time format of `[+|-]h:m`, `INTERVAL [+|-]'[+|-]h:m' HOUR TO MINUTE` when cast to interval hour to minute: 15:40:32.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0."
+ "input" : "15:40:32.99899999",
+ "intervalStr" : "day-time",
+ "supportedFormat" : "`[+|-]h:m`, `INTERVAL [+|-]'[+|-]h:m' HOUR TO MINUTE`",
+ "typeName" : "interval hour to minute"
},
"queryContext" : [ {
"objectType" : "",
@@ -1038,9 +1050,13 @@ select interval '15:40.99899999' hour to second
-- !query analysis
org.apache.spark.sql.catalyst.parser.ParseException
{
- "errorClass" : "_LEGACY_ERROR_TEMP_0063",
+ "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE",
+ "sqlState" : "22006",
"messageParameters" : {
- "msg" : "Interval string does not match day-time format of `[+|-]h:m:s.n`, `INTERVAL [+|-]'[+|-]h:m:s.n' HOUR TO SECOND` when cast to interval hour to second: 15:40.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0."
+ "input" : "15:40.99899999",
+ "intervalStr" : "day-time",
+ "supportedFormat" : "`[+|-]h:m:s.n`, `INTERVAL [+|-]'[+|-]h:m:s.n' HOUR TO SECOND`",
+ "typeName" : "interval hour to second"
},
"queryContext" : [ {
"objectType" : "",
@@ -1057,9 +1073,13 @@ select interval '15:40' hour to second
-- !query analysis
org.apache.spark.sql.catalyst.parser.ParseException
{
- "errorClass" : "_LEGACY_ERROR_TEMP_0063",
+ "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE",
+ "sqlState" : "22006",
"messageParameters" : {
- "msg" : "Interval string does not match day-time format of `[+|-]h:m:s.n`, `INTERVAL [+|-]'[+|-]h:m:s.n' HOUR TO SECOND` when cast to interval hour to second: 15:40, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0."
+ "input" : "15:40",
+ "intervalStr" : "day-time",
+ "supportedFormat" : "`[+|-]h:m:s.n`, `INTERVAL [+|-]'[+|-]h:m:s.n' HOUR TO SECOND`",
+ "typeName" : "interval hour to second"
},
"queryContext" : [ {
"objectType" : "",
@@ -1076,9 +1096,13 @@ select interval '20 40:32.99899999' minute to second
-- !query analysis
org.apache.spark.sql.catalyst.parser.ParseException
{
- "errorClass" : "_LEGACY_ERROR_TEMP_0063",
+ "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE",
+ "sqlState" : "22006",
"messageParameters" : {
- "msg" : "Interval string does not match day-time format of `[+|-]m:s.n`, `INTERVAL [+|-]'[+|-]m:s.n' MINUTE TO SECOND` when cast to interval minute to second: 20 40:32.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0."
+ "input" : "20 40:32.99899999",
+ "intervalStr" : "day-time",
+ "supportedFormat" : "`[+|-]m:s.n`, `INTERVAL [+|-]'[+|-]m:s.n' MINUTE TO SECOND`",
+ "typeName" : "interval minute to second"
},
"queryContext" : [ {
"objectType" : "",
@@ -1460,9 +1484,11 @@ SELECT INTERVAL '178956970-8' YEAR TO MONTH
-- !query analysis
org.apache.spark.sql.catalyst.parser.ParseException
{
- "errorClass" : "_LEGACY_ERROR_TEMP_0063",
+ "errorClass" : "INVALID_INTERVAL_FORMAT.INTERVAL_PARSING",
+ "sqlState" : "22006",
"messageParameters" : {
- "msg" : "Error parsing interval year-month string: integer overflow"
+ "input" : "178956970-8",
+ "interval" : "year-month"
},
"queryContext" : [ {
"objectType" : "",
@@ -1909,9 +1935,13 @@ select interval '-\t2-2\t' year to month
-- !query analysis
org.apache.spark.sql.catalyst.parser.ParseException
{
- "errorClass" : "_LEGACY_ERROR_TEMP_0063",
+ "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING",
+ "sqlState" : "22006",
"messageParameters" : {
- "msg" : "Interval string does not match year-month format of `[+|-]y-m`, `INTERVAL [+|-]'[+|-]y-m' YEAR TO MONTH` when cast to interval year to month: -\t2-2\t"
+ "input" : "-\t2-2\t",
+ "intervalStr" : "year-month",
+ "supportedFormat" : "`[+|-]y-m`, `INTERVAL [+|-]'[+|-]y-m' YEAR TO MONTH`",
+ "typeName" : "interval year to month"
},
"queryContext" : [ {
"objectType" : "",
@@ -1935,9 +1965,13 @@ select interval '\n-\t10\t 12:34:46.789\t' day to second
-- !query analysis
org.apache.spark.sql.catalyst.parser.ParseException
{
- "errorClass" : "_LEGACY_ERROR_TEMP_0063",
+ "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE",
+ "sqlState" : "22006",
"messageParameters" : {
- "msg" : "Interval string does not match day-time format of `[+|-]d h:m:s.n`, `INTERVAL [+|-]'[+|-]d h:m:s.n' DAY TO SECOND` when cast to interval day to second: \n-\t10\t 12:34:46.789\t, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0."
+ "input" : "\n-\t10\t 12:34:46.789\t",
+ "intervalStr" : "day-time",
+ "supportedFormat" : "`[+|-]d h:m:s.n`, `INTERVAL [+|-]'[+|-]d h:m:s.n' DAY TO SECOND`",
+ "typeName" : "interval day to second"
},
"queryContext" : [ {
"objectType" : "",
@@ -2074,7 +2108,7 @@ SELECT
to_csv(named_struct('a', interval 32 year, 'b', interval 10 month)),
from_csv(to_csv(named_struct('a', interval 32 year, 'b', interval 10 month)), 'a interval year, b interval month')
-- !query analysis
-Project [from_json(StructField(a,CalendarIntervalType,true), {"a":"1 days"}, Some(America/Los_Angeles)) AS from_json({"a":"1 days"})#x, from_csv(StructField(a,IntegerType,true), StructField(b,YearMonthIntervalType(0,0),true), 1, 1, Some(America/Los_Angeles), None) AS from_csv(1, 1)#x, to_json(from_json(StructField(a,CalendarIntervalType,true), {"a":"1 days"}, Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1 days"}))#x, to_csv(from_csv(StructField(a,IntegerType,true), StructField(b,YearMonthIntervalType(0,0),true), 1, 1, Some(America/Los_Angeles), None), Some(America/Los_Angeles)) AS to_csv(from_csv(1, 1))#x, to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH), Some(America/Los_Angeles)) AS to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH))#x, from_csv(StructField(a,YearMonthIntervalType(0,0),true), StructField(b,YearMonthIntervalType(1,1),true), to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH), Some(America/Los_Angeles)), Some(America/Los_Angeles), None) AS from_csv(to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH)))#x]
+Project [from_json(StructField(a,CalendarIntervalType,true), {"a":"1 days"}, Some(America/Los_Angeles), false) AS from_json({"a":"1 days"})#x, from_csv(StructField(a,IntegerType,true), StructField(b,YearMonthIntervalType(0,0),true), 1, 1, Some(America/Los_Angeles), None) AS from_csv(1, 1)#x, to_json(from_json(StructField(a,CalendarIntervalType,true), {"a":"1 days"}, Some(America/Los_Angeles), false), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1 days"}))#x, to_csv(from_csv(StructField(a,IntegerType,true), StructField(b,YearMonthIntervalType(0,0),true), 1, 1, Some(America/Los_Angeles), None), Some(America/Los_Angeles)) AS to_csv(from_csv(1, 1))#x, to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH), Some(America/Los_Angeles)) AS to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH))#x, from_csv(StructField(a,YearMonthIntervalType(0,0),true), StructField(b,YearMonthIntervalType(1,1),true), to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH), Some(America/Los_Angeles)), Some(America/Los_Angeles), None) AS from_csv(to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH)))#x]
+- OneRowRelation
@@ -2085,7 +2119,7 @@ SELECT
to_json(map('a', interval 100 day 130 minute)),
from_json(to_json(map('a', interval 100 day 130 minute)), 'a interval day to minute')
-- !query analysis
-Project [from_json(StructField(a,DayTimeIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles)) AS from_json({"a":"1"})#x, to_json(from_json(StructField(a,DayTimeIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1"}))#x, to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE), Some(America/Los_Angeles)) AS to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE))#x, from_json(StructField(a,DayTimeIntervalType(0,2),true), to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE), Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS from_json(to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE)))#x]
+Project [from_json(StructField(a,DayTimeIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles), false) AS from_json({"a":"1"})#x, to_json(from_json(StructField(a,DayTimeIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles), false), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1"}))#x, to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE), Some(America/Los_Angeles)) AS to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE))#x, from_json(StructField(a,DayTimeIntervalType(0,2),true), to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE), Some(America/Los_Angeles)), Some(America/Los_Angeles), false) AS from_json(to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE)))#x]
+- OneRowRelation
@@ -2096,7 +2130,7 @@ SELECT
to_json(map('a', interval 32 year 10 month)),
from_json(to_json(map('a', interval 32 year 10 month)), 'a interval year to month')
-- !query analysis
-Project [from_json(StructField(a,YearMonthIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles)) AS from_json({"a":"1"})#x, to_json(from_json(StructField(a,YearMonthIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1"}))#x, to_json(map(a, INTERVAL '32-10' YEAR TO MONTH), Some(America/Los_Angeles)) AS to_json(map(a, INTERVAL '32-10' YEAR TO MONTH))#x, from_json(StructField(a,YearMonthIntervalType(0,1),true), to_json(map(a, INTERVAL '32-10' YEAR TO MONTH), Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS from_json(to_json(map(a, INTERVAL '32-10' YEAR TO MONTH)))#x]
+Project [from_json(StructField(a,YearMonthIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles), false) AS from_json({"a":"1"})#x, to_json(from_json(StructField(a,YearMonthIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles), false), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1"}))#x, to_json(map(a, INTERVAL '32-10' YEAR TO MONTH), Some(America/Los_Angeles)) AS to_json(map(a, INTERVAL '32-10' YEAR TO MONTH))#x, from_json(StructField(a,YearMonthIntervalType(0,1),true), to_json(map(a, INTERVAL '32-10' YEAR TO MONTH), Some(America/Los_Angeles)), Some(America/Los_Angeles), false) AS from_json(to_json(map(a, INTERVAL '32-10' YEAR TO MONTH)))#x]
+- OneRowRelation
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/parse-schema-string.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/parse-schema-string.sql.out
index 45fc3bd03a782..ae8e47ed3665c 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/parse-schema-string.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/parse-schema-string.sql.out
@@ -16,12 +16,12 @@ Project [from_csv(StructField(cube,IntegerType,true), 1, Some(America/Los_Angele
-- !query
select from_json('{"create":1}', 'create INT')
-- !query analysis
-Project [from_json(StructField(create,IntegerType,true), {"create":1}, Some(America/Los_Angeles)) AS from_json({"create":1})#x]
+Project [from_json(StructField(create,IntegerType,true), {"create":1}, Some(America/Los_Angeles), false) AS from_json({"create":1})#x]
+- OneRowRelation
-- !query
select from_json('{"cube":1}', 'cube INT')
-- !query analysis
-Project [from_json(StructField(cube,IntegerType,true), {"cube":1}, Some(America/Los_Angeles)) AS from_json({"cube":1})#x]
+Project [from_json(StructField(cube,IntegerType,true), {"cube":1}, Some(America/Los_Angeles), false) AS from_json({"cube":1})#x]
+- OneRowRelation
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/timestamp.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/timestamp.sql.out
index bf34490d657e3..560974d28c545 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/timestamp.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/timestamp.sql.out
@@ -730,7 +730,7 @@ Project [unix_timestamp(22 05 2020 Friday, dd MM yyyy EEEEE, Some(America/Los_An
-- !query
select from_json('{"t":"26/October/2015"}', 't Timestamp', map('timestampFormat', 'dd/MMMMM/yyyy'))
-- !query analysis
-Project [from_json(StructField(t,TimestampType,true), (timestampFormat,dd/MMMMM/yyyy), {"t":"26/October/2015"}, Some(America/Los_Angeles)) AS from_json({"t":"26/October/2015"})#x]
+Project [from_json(StructField(t,TimestampType,true), (timestampFormat,dd/MMMMM/yyyy), {"t":"26/October/2015"}, Some(America/Los_Angeles), false) AS from_json({"t":"26/October/2015"})#x]
+- OneRowRelation
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/array.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/array.sql.out
index fb331089d7545..4db56d6c70561 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/array.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/array.sql.out
@@ -194,25 +194,8 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
-- !query
select sort_array(array('b', 'd'), cast(NULL as boolean))
-- !query analysis
-org.apache.spark.sql.catalyst.ExtendedAnalysisException
-{
- "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
- "sqlState" : "42K09",
- "messageParameters" : {
- "inputSql" : "\"CAST(NULL AS BOOLEAN)\"",
- "inputType" : "\"BOOLEAN\"",
- "paramIndex" : "second",
- "requiredType" : "\"BOOLEAN\"",
- "sqlExpr" : "\"sort_array(array(b, d), CAST(NULL AS BOOLEAN))\""
- },
- "queryContext" : [ {
- "objectType" : "",
- "objectName" : "",
- "startIndex" : 8,
- "stopIndex" : 57,
- "fragment" : "sort_array(array('b', 'd'), cast(NULL as boolean))"
- } ]
-}
+Project [sort_array(array(b, d), cast(null as boolean)) AS sort_array(array(b, d), CAST(NULL AS BOOLEAN))#x]
++- OneRowRelation
-- !query
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out
index 14ac67eb93a32..eed7fa73ab698 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out
@@ -436,6 +436,30 @@ Project [str_to_map(collate(text#x, utf8_binary), collate(pairDelim#x, utf8_bina
+- Relation spark_catalog.default.t4[text#x,pairDelim#x,keyValueDelim#x] parquet
+-- !query
+select str_to_map(text collate unicode_ai, pairDelim collate unicode_ai, keyValueDelim collate unicode_ai) from t4
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"collate(text, unicode_ai)\"",
+ "inputType" : "\"STRING COLLATE UNICODE_AI\"",
+ "paramIndex" : "first",
+ "requiredType" : "\"STRING\"",
+ "sqlExpr" : "\"str_to_map(collate(text, unicode_ai), collate(pairDelim, unicode_ai), collate(keyValueDelim, unicode_ai))\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 106,
+ "fragment" : "str_to_map(text collate unicode_ai, pairDelim collate unicode_ai, keyValueDelim collate unicode_ai)"
+ } ]
+}
+
+
-- !query
drop table t4
-- !query analysis
@@ -444,21 +468,227 @@ DropTable false, false
-- !query
-create table t5(str string collate utf8_binary, delimiter string collate utf8_lcase, partNum int) using parquet
+create table t5(s string, utf8_binary string collate utf8_binary, utf8_lcase string collate utf8_lcase) using parquet
-- !query analysis
CreateDataSourceTableCommand `spark_catalog`.`default`.`t5`, false
-- !query
-insert into t5 values('11AB12AB13', 'AB', 2)
+insert into t5 values ('Spark', 'Spark', 'SQL')
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t5, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t5], Append, `spark_catalog`.`default`.`t5`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t5), [s, utf8_binary, utf8_lcase]
++- Project [cast(col1#x as string) AS s#x, cast(col2#x as string) AS utf8_binary#x, cast(col3#x as string collate UTF8_LCASE) AS utf8_lcase#x]
+ +- LocalRelation [col1#x, col2#x, col3#x]
+
+
+-- !query
+insert into t5 values ('aaAaAAaA', 'aaAaAAaA', 'aaAaAAaA')
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t5, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t5], Append, `spark_catalog`.`default`.`t5`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t5), [s, utf8_binary, utf8_lcase]
++- Project [cast(col1#x as string) AS s#x, cast(col2#x as string) AS utf8_binary#x, cast(col3#x as string collate UTF8_LCASE) AS utf8_lcase#x]
+ +- LocalRelation [col1#x, col2#x, col3#x]
+
+
+-- !query
+insert into t5 values ('aaAaAAaA', 'aaAaAAaA', 'aaAaaAaA')
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t5, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t5], Append, `spark_catalog`.`default`.`t5`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t5), [s, utf8_binary, utf8_lcase]
++- Project [cast(col1#x as string) AS s#x, cast(col2#x as string) AS utf8_binary#x, cast(col3#x as string collate UTF8_LCASE) AS utf8_lcase#x]
+ +- LocalRelation [col1#x, col2#x, col3#x]
+
+
+-- !query
+insert into t5 values ('aaAaAAaA', 'aaAaAAaA', 'aaAaaAaAaaAaaAaAaaAaaAaA')
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t5, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t5], Append, `spark_catalog`.`default`.`t5`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t5), [s, utf8_binary, utf8_lcase]
++- Project [cast(col1#x as string) AS s#x, cast(col2#x as string) AS utf8_binary#x, cast(col3#x as string collate UTF8_LCASE) AS utf8_lcase#x]
+ +- LocalRelation [col1#x, col2#x, col3#x]
+
+
+-- !query
+insert into t5 values ('bbAbaAbA', 'bbAbAAbA', 'a')
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t5, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t5], Append, `spark_catalog`.`default`.`t5`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t5), [s, utf8_binary, utf8_lcase]
++- Project [cast(col1#x as string) AS s#x, cast(col2#x as string) AS utf8_binary#x, cast(col3#x as string collate UTF8_LCASE) AS utf8_lcase#x]
+ +- LocalRelation [col1#x, col2#x, col3#x]
+
+
+-- !query
+insert into t5 values ('İo', 'İo', 'İo')
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t5, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t5], Append, `spark_catalog`.`default`.`t5`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t5), [s, utf8_binary, utf8_lcase]
++- Project [cast(col1#x as string) AS s#x, cast(col2#x as string) AS utf8_binary#x, cast(col3#x as string collate UTF8_LCASE) AS utf8_lcase#x]
+ +- LocalRelation [col1#x, col2#x, col3#x]
+
+
+-- !query
+insert into t5 values ('İo', 'İo', 'i̇o')
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t5, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t5], Append, `spark_catalog`.`default`.`t5`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t5), [s, utf8_binary, utf8_lcase]
++- Project [cast(col1#x as string) AS s#x, cast(col2#x as string) AS utf8_binary#x, cast(col3#x as string collate UTF8_LCASE) AS utf8_lcase#x]
+ +- LocalRelation [col1#x, col2#x, col3#x]
+
+
+-- !query
+insert into t5 values ('efd2', 'efd2', 'efd2')
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t5, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t5], Append, `spark_catalog`.`default`.`t5`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t5), [s, utf8_binary, utf8_lcase]
++- Project [cast(col1#x as string) AS s#x, cast(col2#x as string) AS utf8_binary#x, cast(col3#x as string collate UTF8_LCASE) AS utf8_lcase#x]
+ +- LocalRelation [col1#x, col2#x, col3#x]
+
+
+-- !query
+insert into t5 values ('Hello, world! Nice day.', 'Hello, world! Nice day.', 'Hello, world! Nice day.')
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t5, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t5], Append, `spark_catalog`.`default`.`t5`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t5), [s, utf8_binary, utf8_lcase]
++- Project [cast(col1#x as string) AS s#x, cast(col2#x as string) AS utf8_binary#x, cast(col3#x as string collate UTF8_LCASE) AS utf8_lcase#x]
+ +- LocalRelation [col1#x, col2#x, col3#x]
+
+
+-- !query
+insert into t5 values ('Something else. Nothing here.', 'Something else. Nothing here.', 'Something else. Nothing here.')
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t5, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t5], Append, `spark_catalog`.`default`.`t5`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t5), [s, utf8_binary, utf8_lcase]
++- Project [cast(col1#x as string) AS s#x, cast(col2#x as string) AS utf8_binary#x, cast(col3#x as string collate UTF8_LCASE) AS utf8_lcase#x]
+ +- LocalRelation [col1#x, col2#x, col3#x]
+
+
+-- !query
+insert into t5 values ('kitten', 'kitten', 'sitTing')
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t5, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t5], Append, `spark_catalog`.`default`.`t5`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t5), [s, utf8_binary, utf8_lcase]
++- Project [cast(col1#x as string) AS s#x, cast(col2#x as string) AS utf8_binary#x, cast(col3#x as string collate UTF8_LCASE) AS utf8_lcase#x]
+ +- LocalRelation [col1#x, col2#x, col3#x]
+
+
+-- !query
+insert into t5 values ('abc', 'abc', 'abc')
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t5, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t5], Append, `spark_catalog`.`default`.`t5`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t5), [s, utf8_binary, utf8_lcase]
++- Project [cast(col1#x as string) AS s#x, cast(col2#x as string) AS utf8_binary#x, cast(col3#x as string collate UTF8_LCASE) AS utf8_lcase#x]
+ +- LocalRelation [col1#x, col2#x, col3#x]
+
+
+-- !query
+insert into t5 values ('abcdcba', 'abcdcba', 'aBcDCbA')
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t5, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t5], Append, `spark_catalog`.`default`.`t5`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t5), [s, utf8_binary, utf8_lcase]
++- Project [cast(col1#x as string) AS s#x, cast(col2#x as string) AS utf8_binary#x, cast(col3#x as string collate UTF8_LCASE) AS utf8_lcase#x]
+ +- LocalRelation [col1#x, col2#x, col3#x]
+
+
+-- !query
+create table t6(ascii long) using parquet
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`default`.`t6`, false
+
+
+-- !query
+insert into t6 values (97)
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t6, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t6], Append, `spark_catalog`.`default`.`t6`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t6), [ascii]
++- Project [cast(col1#x as bigint) AS ascii#xL]
+ +- LocalRelation [col1#x]
+
+
+-- !query
+insert into t6 values (66)
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t6, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t6], Append, `spark_catalog`.`default`.`t6`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t6), [ascii]
++- Project [cast(col1#x as bigint) AS ascii#xL]
+ +- LocalRelation [col1#x]
+
+
+-- !query
+create table t7(ascii double) using parquet
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`default`.`t7`, false
+
+
+-- !query
+insert into t7 values (97.52143)
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t7, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t7], Append, `spark_catalog`.`default`.`t7`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t7), [ascii]
++- Project [cast(col1#x as double) AS ascii#x]
+ +- LocalRelation [col1#x]
+
+
+-- !query
+insert into t7 values (66.421)
-- !query analysis
-InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t5, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t5], Append, `spark_catalog`.`default`.`t5`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t5), [str, delimiter, partNum]
-+- Project [cast(col1#x as string) AS str#x, cast(col2#x as string collate UTF8_LCASE) AS delimiter#x, cast(col3#x as int) AS partNum#x]
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t7, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t7], Append, `spark_catalog`.`default`.`t7`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t7), [ascii]
++- Project [cast(col1#x as double) AS ascii#x]
+ +- LocalRelation [col1#x]
+
+
+-- !query
+create table t8(format string collate utf8_binary, utf8_binary string collate utf8_binary, utf8_lcase string collate utf8_lcase) using parquet
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`default`.`t8`, false
+
+
+-- !query
+insert into t8 values ('%s%s', 'abCdE', 'abCdE')
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t8, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t8], Append, `spark_catalog`.`default`.`t8`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t8), [format, utf8_binary, utf8_lcase]
++- Project [cast(col1#x as string) AS format#x, cast(col2#x as string) AS utf8_binary#x, cast(col3#x as string collate UTF8_LCASE) AS utf8_lcase#x]
+- LocalRelation [col1#x, col2#x, col3#x]
-- !query
-select split_part(str, delimiter, partNum) from t5
+create table t9(num long) using parquet
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`default`.`t9`, false
+
+
+-- !query
+insert into t9 values (97)
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t9, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t9], Append, `spark_catalog`.`default`.`t9`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t9), [num]
++- Project [cast(col1#x as bigint) AS num#xL]
+ +- LocalRelation [col1#x]
+
+
+-- !query
+insert into t9 values (66)
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t9, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t9], Append, `spark_catalog`.`default`.`t9`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t9), [num]
++- Project [cast(col1#x as bigint) AS num#xL]
+ +- LocalRelation [col1#x]
+
+
+-- !query
+create table t10(utf8_binary string collate utf8_binary, utf8_lcase string collate utf8_lcase) using parquet
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`default`.`t10`, false
+
+
+-- !query
+insert into t10 values ('aaAaAAaA', 'aaAaaAaA')
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t10, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t10], Append, `spark_catalog`.`default`.`t10`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t10), [utf8_binary, utf8_lcase]
++- Project [cast(col1#x as string) AS utf8_binary#x, cast(col2#x as string collate UTF8_LCASE) AS utf8_lcase#x]
+ +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+insert into t10 values ('efd2', 'efd2')
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t10, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t10], Append, `spark_catalog`.`default`.`t10`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t10), [utf8_binary, utf8_lcase]
++- Project [cast(col1#x as string) AS utf8_binary#x, cast(col2#x as string collate UTF8_LCASE) AS utf8_lcase#x]
+ +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+select concat_ws(' ', utf8_lcase, utf8_lcase) from t5
+-- !query analysis
+Project [concat_ws(cast( as string collate UTF8_LCASE), utf8_lcase#x, utf8_lcase#x) AS concat_ws( , utf8_lcase, utf8_lcase)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select concat_ws(' ', utf8_binary, utf8_lcase) from t5
-- !query analysis
org.apache.spark.sql.AnalysisException
{
@@ -468,7 +698,7 @@ org.apache.spark.sql.AnalysisException
-- !query
-select split_part(str collate utf8_binary, delimiter collate utf8_lcase, partNum) from t5
+select concat_ws(' ' collate utf8_binary, utf8_binary, 'SQL' collate utf8_lcase) from t5
-- !query analysis
org.apache.spark.sql.AnalysisException
{
@@ -481,36 +711,39 @@ org.apache.spark.sql.AnalysisException
-- !query
-select split_part(str collate utf8_binary, delimiter collate utf8_binary, partNum) from t5
+select concat_ws(' ' collate utf8_lcase, utf8_binary, 'SQL' collate utf8_lcase) from t5
-- !query analysis
-Project [split_part(collate(str#x, utf8_binary), collate(delimiter#x, utf8_binary), partNum#x) AS split_part(collate(str, utf8_binary), collate(delimiter, utf8_binary), partNum)#x]
+Project [concat_ws(collate( , utf8_lcase), cast(utf8_binary#x as string collate UTF8_LCASE), collate(SQL, utf8_lcase)) AS concat_ws(collate( , utf8_lcase), utf8_binary, collate(SQL, utf8_lcase))#x]
+- SubqueryAlias spark_catalog.default.t5
- +- Relation spark_catalog.default.t5[str#x,delimiter#x,partNum#x] parquet
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
-- !query
-drop table t5
+select concat_ws(',', utf8_lcase, 'word'), concat_ws(',', utf8_binary, 'word') from t5
-- !query analysis
-DropTable false, false
-+- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t5
+Project [concat_ws(cast(, as string collate UTF8_LCASE), utf8_lcase#x, cast(word as string collate UTF8_LCASE)) AS concat_ws(,, utf8_lcase, word)#x, concat_ws(,, utf8_binary#x, word) AS concat_ws(,, utf8_binary, word)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
-- !query
-create table t6 (utf8_binary string collate utf8_binary, utf8_lcase string collate utf8_lcase, threshold int) using parquet
+select concat_ws(',', utf8_lcase, 'word' collate utf8_binary), concat_ws(',', utf8_binary, 'word' collate utf8_lcase) from t5
-- !query analysis
-CreateDataSourceTableCommand `spark_catalog`.`default`.`t6`, false
+Project [concat_ws(,, cast(utf8_lcase#x as string), collate(word, utf8_binary)) AS concat_ws(,, utf8_lcase, collate(word, utf8_binary))#x, concat_ws(cast(, as string collate UTF8_LCASE), cast(utf8_binary#x as string collate UTF8_LCASE), collate(word, utf8_lcase)) AS concat_ws(,, utf8_binary, collate(word, utf8_lcase))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
-- !query
-insert into t6 values('kitten', 'sitting', 2)
+select elt(2, s, utf8_binary) from t5
-- !query analysis
-InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t6, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t6], Append, `spark_catalog`.`default`.`t6`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t6), [utf8_binary, utf8_lcase, threshold]
-+- Project [cast(col1#x as string) AS utf8_binary#x, cast(col2#x as string collate UTF8_LCASE) AS utf8_lcase#x, cast(col3#x as int) AS threshold#x]
- +- LocalRelation [col1#x, col2#x, col3#x]
+Project [elt(2, s#x, utf8_binary#x, false) AS elt(2, s, utf8_binary)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
-- !query
-select levenshtein(utf8_binary, utf8_lcase) from t6
+select elt(2, utf8_binary, utf8_lcase, s) from t5
-- !query analysis
org.apache.spark.sql.AnalysisException
{
@@ -520,7 +753,7 @@ org.apache.spark.sql.AnalysisException
-- !query
-select levenshtein(utf8_binary collate utf8_binary, utf8_lcase collate utf8_lcase) from t6
+select elt(1, utf8_binary collate utf8_binary, utf8_lcase collate utf8_lcase) from t5
-- !query analysis
org.apache.spark.sql.AnalysisException
{
@@ -533,15 +766,39 @@ org.apache.spark.sql.AnalysisException
-- !query
-select levenshtein(utf8_binary collate utf8_binary, utf8_lcase collate utf8_binary) from t6
+select elt(1, utf8_binary collate utf8_binary, utf8_lcase collate utf8_binary) from t5
-- !query analysis
-Project [levenshtein(collate(utf8_binary#x, utf8_binary), collate(utf8_lcase#x, utf8_binary), None) AS levenshtein(collate(utf8_binary, utf8_binary), collate(utf8_lcase, utf8_binary))#x]
-+- SubqueryAlias spark_catalog.default.t6
- +- Relation spark_catalog.default.t6[utf8_binary#x,utf8_lcase#x,threshold#x] parquet
+Project [elt(1, collate(utf8_binary#x, utf8_binary), collate(utf8_lcase#x, utf8_binary), false) AS elt(1, collate(utf8_binary, utf8_binary), collate(utf8_lcase, utf8_binary))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select elt(1, utf8_binary collate utf8_binary, utf8_lcase) from t5
+-- !query analysis
+Project [elt(1, collate(utf8_binary#x, utf8_binary), cast(utf8_lcase#x as string), false) AS elt(1, collate(utf8_binary, utf8_binary), utf8_lcase)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select elt(1, utf8_binary, 'word'), elt(1, utf8_lcase, 'word') from t5
+-- !query analysis
+Project [elt(1, utf8_binary#x, word, false) AS elt(1, utf8_binary, word)#x, elt(1, utf8_lcase#x, cast(word as string collate UTF8_LCASE), false) AS elt(1, utf8_lcase, word)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
-- !query
-select levenshtein(utf8_binary, utf8_lcase, threshold) from t6
+select elt(1, utf8_binary, 'word' collate utf8_lcase), elt(1, utf8_lcase, 'word' collate utf8_binary) from t5
+-- !query analysis
+Project [elt(1, cast(utf8_binary#x as string collate UTF8_LCASE), collate(word, utf8_lcase), false) AS elt(1, utf8_binary, collate(word, utf8_lcase))#x, elt(1, cast(utf8_lcase#x as string), collate(word, utf8_binary), false) AS elt(1, utf8_lcase, collate(word, utf8_binary))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select split_part(utf8_binary, utf8_lcase, 3) from t5
-- !query analysis
org.apache.spark.sql.AnalysisException
{
@@ -551,7 +808,15 @@ org.apache.spark.sql.AnalysisException
-- !query
-select levenshtein(utf8_binary collate utf8_binary, utf8_lcase collate utf8_lcase, threshold) from t6
+select split_part(s, utf8_binary, 1) from t5
+-- !query analysis
+Project [split_part(s#x, utf8_binary#x, 1) AS split_part(s, utf8_binary, 1)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select split_part(utf8_binary collate utf8_binary, s collate utf8_lcase, 1) from t5
-- !query analysis
org.apache.spark.sql.AnalysisException
{
@@ -564,15 +829,1857 @@ org.apache.spark.sql.AnalysisException
-- !query
-select levenshtein(utf8_binary collate utf8_binary, utf8_lcase collate utf8_binary, threshold) from t6
+select split_part(utf8_binary, utf8_lcase collate utf8_binary, 2) from t5
-- !query analysis
-Project [levenshtein(collate(utf8_binary#x, utf8_binary), collate(utf8_lcase#x, utf8_binary), Some(threshold#x)) AS levenshtein(collate(utf8_binary, utf8_binary), collate(utf8_lcase, utf8_binary), threshold)#x]
-+- SubqueryAlias spark_catalog.default.t6
- +- Relation spark_catalog.default.t6[utf8_binary#x,utf8_lcase#x,threshold#x] parquet
+Project [split_part(utf8_binary#x, collate(utf8_lcase#x, utf8_binary), 2) AS split_part(utf8_binary, collate(utf8_lcase, utf8_binary), 2)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
-- !query
-drop table t6
+select split_part(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase, 2) from t5
-- !query analysis
-DropTable false, false
-+- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t6
+Project [split_part(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_lcase), 2) AS split_part(collate(utf8_binary, utf8_lcase), collate(utf8_lcase, utf8_lcase), 2)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select split_part(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2) from t5
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"collate(utf8_binary, unicode_ai)\"",
+ "inputType" : "\"STRING COLLATE UNICODE_AI\"",
+ "paramIndex" : "first",
+ "requiredType" : "\"STRING\"",
+ "sqlExpr" : "\"split_part(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), 2)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 83,
+ "fragment" : "split_part(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2)"
+ } ]
+}
+
+
+-- !query
+select split_part(utf8_binary, 'a', 3), split_part(utf8_lcase, 'a', 3) from t5
+-- !query analysis
+Project [split_part(utf8_binary#x, a, 3) AS split_part(utf8_binary, a, 3)#x, split_part(utf8_lcase#x, cast(a as string collate UTF8_LCASE), 3) AS split_part(utf8_lcase, a, 3)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select split_part(utf8_binary, 'a' collate utf8_lcase, 3), split_part(utf8_lcase, 'a' collate utf8_binary, 3) from t5
+-- !query analysis
+Project [split_part(cast(utf8_binary#x as string collate UTF8_LCASE), collate(a, utf8_lcase), 3) AS split_part(utf8_binary, collate(a, utf8_lcase), 3)#x, split_part(cast(utf8_lcase#x as string), collate(a, utf8_binary), 3) AS split_part(utf8_lcase, collate(a, utf8_binary), 3)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select contains(utf8_binary, utf8_lcase) from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.IMPLICIT",
+ "sqlState" : "42P21"
+}
+
+
+-- !query
+select contains(s, utf8_binary) from t5
+-- !query analysis
+Project [Contains(s#x, utf8_binary#x) AS contains(s, utf8_binary)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select contains(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.EXPLICIT",
+ "sqlState" : "42P21",
+ "messageParameters" : {
+ "explicitTypes" : "`string`, `string collate UTF8_LCASE`"
+ }
+}
+
+
+-- !query
+select contains(utf8_binary, utf8_lcase collate utf8_binary) from t5
+-- !query analysis
+Project [Contains(utf8_binary#x, collate(utf8_lcase#x, utf8_binary)) AS contains(utf8_binary, collate(utf8_lcase, utf8_binary))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select contains(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5
+-- !query analysis
+Project [Contains(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_lcase)) AS contains(collate(utf8_binary, utf8_lcase), collate(utf8_lcase, utf8_lcase))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select contains(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"collate(utf8_binary, unicode_ai)\"",
+ "inputType" : "\"STRING COLLATE UNICODE_AI\"",
+ "paramIndex" : "first",
+ "requiredType" : "\"STRING\"",
+ "sqlExpr" : "\"contains(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai))\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 78,
+ "fragment" : "contains(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)"
+ } ]
+}
+
+
+-- !query
+select contains(utf8_binary, 'a'), contains(utf8_lcase, 'a') from t5
+-- !query analysis
+Project [Contains(utf8_binary#x, a) AS contains(utf8_binary, a)#x, Contains(utf8_lcase#x, cast(a as string collate UTF8_LCASE)) AS contains(utf8_lcase, a)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select contains(utf8_binary, 'AaAA' collate utf8_lcase), contains(utf8_lcase, 'AAa' collate utf8_binary) from t5
+-- !query analysis
+Project [Contains(cast(utf8_binary#x as string collate UTF8_LCASE), collate(AaAA, utf8_lcase)) AS contains(utf8_binary, collate(AaAA, utf8_lcase))#x, Contains(cast(utf8_lcase#x as string), collate(AAa, utf8_binary)) AS contains(utf8_lcase, collate(AAa, utf8_binary))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select substring_index(utf8_binary, utf8_lcase, 2) from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.IMPLICIT",
+ "sqlState" : "42P21"
+}
+
+
+-- !query
+select substring_index(s, utf8_binary,1) from t5
+-- !query analysis
+Project [substring_index(s#x, utf8_binary#x, 1) AS substring_index(s, utf8_binary, 1)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select substring_index(utf8_binary collate utf8_binary, s collate utf8_lcase, 3) from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.EXPLICIT",
+ "sqlState" : "42P21",
+ "messageParameters" : {
+ "explicitTypes" : "`string`, `string collate UTF8_LCASE`"
+ }
+}
+
+
+-- !query
+select substring_index(utf8_binary, utf8_lcase collate utf8_binary, 2) from t5
+-- !query analysis
+Project [substring_index(utf8_binary#x, collate(utf8_lcase#x, utf8_binary), 2) AS substring_index(utf8_binary, collate(utf8_lcase, utf8_binary), 2)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select substring_index(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase, 2) from t5
+-- !query analysis
+Project [substring_index(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_lcase), 2) AS substring_index(collate(utf8_binary, utf8_lcase), collate(utf8_lcase, utf8_lcase), 2)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select substring_index(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2) from t5
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"collate(utf8_binary, unicode_ai)\"",
+ "inputType" : "\"STRING COLLATE UNICODE_AI\"",
+ "paramIndex" : "first",
+ "requiredType" : "\"STRING\"",
+ "sqlExpr" : "\"substring_index(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), 2)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 88,
+ "fragment" : "substring_index(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2)"
+ } ]
+}
+
+
+-- !query
+select substring_index(utf8_binary, 'a', 2), substring_index(utf8_lcase, 'a', 2) from t5
+-- !query analysis
+Project [substring_index(utf8_binary#x, a, 2) AS substring_index(utf8_binary, a, 2)#x, substring_index(utf8_lcase#x, cast(a as string collate UTF8_LCASE), 2) AS substring_index(utf8_lcase, a, 2)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select substring_index(utf8_binary, 'AaAA' collate utf8_lcase, 2), substring_index(utf8_lcase, 'AAa' collate utf8_binary, 2) from t5
+-- !query analysis
+Project [substring_index(cast(utf8_binary#x as string collate UTF8_LCASE), collate(AaAA, utf8_lcase), 2) AS substring_index(utf8_binary, collate(AaAA, utf8_lcase), 2)#x, substring_index(cast(utf8_lcase#x as string), collate(AAa, utf8_binary), 2) AS substring_index(utf8_lcase, collate(AAa, utf8_binary), 2)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select instr(utf8_binary, utf8_lcase) from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.IMPLICIT",
+ "sqlState" : "42P21"
+}
+
+
+-- !query
+select instr(s, utf8_binary) from t5
+-- !query analysis
+Project [instr(s#x, utf8_binary#x) AS instr(s, utf8_binary)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select instr(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.EXPLICIT",
+ "sqlState" : "42P21",
+ "messageParameters" : {
+ "explicitTypes" : "`string`, `string collate UTF8_LCASE`"
+ }
+}
+
+
+-- !query
+select instr(utf8_binary, utf8_lcase collate utf8_binary) from t5
+-- !query analysis
+Project [instr(utf8_binary#x, collate(utf8_lcase#x, utf8_binary)) AS instr(utf8_binary, collate(utf8_lcase, utf8_binary))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select instr(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5
+-- !query analysis
+Project [instr(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_lcase)) AS instr(collate(utf8_binary, utf8_lcase), collate(utf8_lcase, utf8_lcase))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select instr(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"collate(utf8_binary, unicode_ai)\"",
+ "inputType" : "\"STRING COLLATE UNICODE_AI\"",
+ "paramIndex" : "first",
+ "requiredType" : "\"STRING\"",
+ "sqlExpr" : "\"instr(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai))\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 75,
+ "fragment" : "instr(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)"
+ } ]
+}
+
+
+-- !query
+select instr(utf8_binary, 'a'), instr(utf8_lcase, 'a') from t5
+-- !query analysis
+Project [instr(utf8_binary#x, a) AS instr(utf8_binary, a)#x, instr(utf8_lcase#x, cast(a as string collate UTF8_LCASE)) AS instr(utf8_lcase, a)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select instr(utf8_binary, 'AaAA' collate utf8_lcase), instr(utf8_lcase, 'AAa' collate utf8_binary) from t5
+-- !query analysis
+Project [instr(cast(utf8_binary#x as string collate UTF8_LCASE), collate(AaAA, utf8_lcase)) AS instr(utf8_binary, collate(AaAA, utf8_lcase))#x, instr(cast(utf8_lcase#x as string), collate(AAa, utf8_binary)) AS instr(utf8_lcase, collate(AAa, utf8_binary))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select find_in_set(utf8_binary, utf8_lcase) from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.IMPLICIT",
+ "sqlState" : "42P21"
+}
+
+
+-- !query
+select find_in_set(s, utf8_binary) from t5
+-- !query analysis
+Project [find_in_set(s#x, utf8_binary#x) AS find_in_set(s, utf8_binary)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select find_in_set(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.EXPLICIT",
+ "sqlState" : "42P21",
+ "messageParameters" : {
+ "explicitTypes" : "`string`, `string collate UTF8_LCASE`"
+ }
+}
+
+
+-- !query
+select find_in_set(utf8_binary, utf8_lcase collate utf8_binary) from t5
+-- !query analysis
+Project [find_in_set(utf8_binary#x, collate(utf8_lcase#x, utf8_binary)) AS find_in_set(utf8_binary, collate(utf8_lcase, utf8_binary))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select find_in_set(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5
+-- !query analysis
+Project [find_in_set(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_lcase)) AS find_in_set(collate(utf8_binary, utf8_lcase), collate(utf8_lcase, utf8_lcase))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select find_in_set(utf8_binary, 'aaAaaAaA,i̇o'), find_in_set(utf8_lcase, 'aaAaaAaA,i̇o') from t5
+-- !query analysis
+Project [find_in_set(utf8_binary#x, aaAaaAaA,i̇o) AS find_in_set(utf8_binary, aaAaaAaA,i̇o)#x, find_in_set(utf8_lcase#x, cast(aaAaaAaA,i̇o as string collate UTF8_LCASE)) AS find_in_set(utf8_lcase, aaAaaAaA,i̇o)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select find_in_set(utf8_binary, 'aaAaaAaA,i̇o' collate utf8_lcase), find_in_set(utf8_lcase, 'aaAaaAaA,i̇o' collate utf8_binary) from t5
+-- !query analysis
+Project [find_in_set(cast(utf8_binary#x as string collate UTF8_LCASE), collate(aaAaaAaA,i̇o, utf8_lcase)) AS find_in_set(utf8_binary, collate(aaAaaAaA,i̇o, utf8_lcase))#x, find_in_set(cast(utf8_lcase#x as string), collate(aaAaaAaA,i̇o, utf8_binary)) AS find_in_set(utf8_lcase, collate(aaAaaAaA,i̇o, utf8_binary))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select startswith(utf8_binary, utf8_lcase) from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.IMPLICIT",
+ "sqlState" : "42P21"
+}
+
+
+-- !query
+select startswith(s, utf8_binary) from t5
+-- !query analysis
+Project [StartsWith(s#x, utf8_binary#x) AS startswith(s, utf8_binary)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select startswith(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.EXPLICIT",
+ "sqlState" : "42P21",
+ "messageParameters" : {
+ "explicitTypes" : "`string`, `string collate UTF8_LCASE`"
+ }
+}
+
+
+-- !query
+select startswith(utf8_binary, utf8_lcase collate utf8_binary) from t5
+-- !query analysis
+Project [StartsWith(utf8_binary#x, collate(utf8_lcase#x, utf8_binary)) AS startswith(utf8_binary, collate(utf8_lcase, utf8_binary))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select startswith(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5
+-- !query analysis
+Project [StartsWith(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_lcase)) AS startswith(collate(utf8_binary, utf8_lcase), collate(utf8_lcase, utf8_lcase))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select startswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"collate(utf8_binary, unicode_ai)\"",
+ "inputType" : "\"STRING COLLATE UNICODE_AI\"",
+ "paramIndex" : "first",
+ "requiredType" : "\"STRING\"",
+ "sqlExpr" : "\"startswith(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai))\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 80,
+ "fragment" : "startswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)"
+ } ]
+}
+
+
+-- !query
+select startswith(utf8_binary, 'aaAaaAaA'), startswith(utf8_lcase, 'aaAaaAaA') from t5
+-- !query analysis
+Project [StartsWith(utf8_binary#x, aaAaaAaA) AS startswith(utf8_binary, aaAaaAaA)#x, StartsWith(utf8_lcase#x, cast(aaAaaAaA as string collate UTF8_LCASE)) AS startswith(utf8_lcase, aaAaaAaA)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select startswith(utf8_binary, 'aaAaaAaA' collate utf8_lcase), startswith(utf8_lcase, 'aaAaaAaA' collate utf8_binary) from t5
+-- !query analysis
+Project [StartsWith(cast(utf8_binary#x as string collate UTF8_LCASE), collate(aaAaaAaA, utf8_lcase)) AS startswith(utf8_binary, collate(aaAaaAaA, utf8_lcase))#x, StartsWith(cast(utf8_lcase#x as string), collate(aaAaaAaA, utf8_binary)) AS startswith(utf8_lcase, collate(aaAaaAaA, utf8_binary))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select translate(utf8_lcase, utf8_lcase, '12345') from t5
+-- !query analysis
+Project [translate(utf8_lcase#x, utf8_lcase#x, cast(12345 as string collate UTF8_LCASE)) AS translate(utf8_lcase, utf8_lcase, 12345)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select translate(utf8_binary, utf8_lcase, '12345') from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.IMPLICIT",
+ "sqlState" : "42P21"
+}
+
+
+-- !query
+select translate(utf8_binary, 'aBc' collate utf8_lcase, '12345' collate utf8_binary) from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.EXPLICIT",
+ "sqlState" : "42P21",
+ "messageParameters" : {
+ "explicitTypes" : "`string collate UTF8_LCASE`, `string`"
+ }
+}
+
+
+-- !query
+select translate(utf8_binary, 'SQL' collate utf8_lcase, '12345' collate utf8_lcase) from t5
+-- !query analysis
+Project [translate(cast(utf8_binary#x as string collate UTF8_LCASE), collate(SQL, utf8_lcase), collate(12345, utf8_lcase)) AS translate(utf8_binary, collate(SQL, utf8_lcase), collate(12345, utf8_lcase))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select translate(utf8_binary, 'SQL' collate unicode_ai, '12345' collate unicode_ai) from t5
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"utf8_binary\"",
+ "inputType" : "\"STRING COLLATE UNICODE_AI\"",
+ "paramIndex" : "first",
+ "requiredType" : "\"STRING\"",
+ "sqlExpr" : "\"translate(utf8_binary, collate(SQL, unicode_ai), collate(12345, unicode_ai))\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 83,
+ "fragment" : "translate(utf8_binary, 'SQL' collate unicode_ai, '12345' collate unicode_ai)"
+ } ]
+}
+
+
+-- !query
+select translate(utf8_lcase, 'aaAaaAaA', '12345'), translate(utf8_binary, 'aaAaaAaA', '12345') from t5
+-- !query analysis
+Project [translate(utf8_lcase#x, cast(aaAaaAaA as string collate UTF8_LCASE), cast(12345 as string collate UTF8_LCASE)) AS translate(utf8_lcase, aaAaaAaA, 12345)#x, translate(utf8_binary#x, aaAaaAaA, 12345) AS translate(utf8_binary, aaAaaAaA, 12345)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select translate(utf8_lcase, 'aBc' collate utf8_binary, '12345'), translate(utf8_binary, 'aBc' collate utf8_lcase, '12345') from t5
+-- !query analysis
+Project [translate(cast(utf8_lcase#x as string), collate(aBc, utf8_binary), 12345) AS translate(utf8_lcase, collate(aBc, utf8_binary), 12345)#x, translate(cast(utf8_binary#x as string collate UTF8_LCASE), collate(aBc, utf8_lcase), cast(12345 as string collate UTF8_LCASE)) AS translate(utf8_binary, collate(aBc, utf8_lcase), 12345)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select replace(utf8_binary, utf8_lcase, 'abc') from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.IMPLICIT",
+ "sqlState" : "42P21"
+}
+
+
+-- !query
+select replace(s, utf8_binary, 'abc') from t5
+-- !query analysis
+Project [replace(s#x, utf8_binary#x, abc) AS replace(s, utf8_binary, abc)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select replace(utf8_binary collate utf8_binary, s collate utf8_lcase, 'abc') from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.EXPLICIT",
+ "sqlState" : "42P21",
+ "messageParameters" : {
+ "explicitTypes" : "`string`, `string collate UTF8_LCASE`"
+ }
+}
+
+
+-- !query
+select replace(utf8_binary, utf8_lcase collate utf8_binary, 'abc') from t5
+-- !query analysis
+Project [replace(utf8_binary#x, collate(utf8_lcase#x, utf8_binary), abc) AS replace(utf8_binary, collate(utf8_lcase, utf8_binary), abc)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select replace(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase, 'abc') from t5
+-- !query analysis
+Project [replace(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_lcase), cast(abc as string collate UTF8_LCASE)) AS replace(collate(utf8_binary, utf8_lcase), collate(utf8_lcase, utf8_lcase), abc)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select replace(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 'abc') from t5
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"collate(utf8_binary, unicode_ai)\"",
+ "inputType" : "\"STRING COLLATE UNICODE_AI\"",
+ "paramIndex" : "first",
+ "requiredType" : "\"STRING\"",
+ "sqlExpr" : "\"replace(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), abc)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 84,
+ "fragment" : "replace(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 'abc')"
+ } ]
+}
+
+
+-- !query
+select replace(utf8_binary, 'aaAaaAaA', 'abc'), replace(utf8_lcase, 'aaAaaAaA', 'abc') from t5
+-- !query analysis
+Project [replace(utf8_binary#x, aaAaaAaA, abc) AS replace(utf8_binary, aaAaaAaA, abc)#x, replace(utf8_lcase#x, cast(aaAaaAaA as string collate UTF8_LCASE), cast(abc as string collate UTF8_LCASE)) AS replace(utf8_lcase, aaAaaAaA, abc)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select replace(utf8_binary, 'aaAaaAaA' collate utf8_lcase, 'abc'), replace(utf8_lcase, 'aaAaaAaA' collate utf8_binary, 'abc') from t5
+-- !query analysis
+Project [replace(cast(utf8_binary#x as string collate UTF8_LCASE), collate(aaAaaAaA, utf8_lcase), cast(abc as string collate UTF8_LCASE)) AS replace(utf8_binary, collate(aaAaaAaA, utf8_lcase), abc)#x, replace(cast(utf8_lcase#x as string), collate(aaAaaAaA, utf8_binary), abc) AS replace(utf8_lcase, collate(aaAaaAaA, utf8_binary), abc)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select endswith(utf8_binary, utf8_lcase) from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.IMPLICIT",
+ "sqlState" : "42P21"
+}
+
+
+-- !query
+select endswith(s, utf8_binary) from t5
+-- !query analysis
+Project [EndsWith(s#x, utf8_binary#x) AS endswith(s, utf8_binary)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select endswith(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.EXPLICIT",
+ "sqlState" : "42P21",
+ "messageParameters" : {
+ "explicitTypes" : "`string`, `string collate UTF8_LCASE`"
+ }
+}
+
+
+-- !query
+select endswith(utf8_binary, utf8_lcase collate utf8_binary) from t5
+-- !query analysis
+Project [EndsWith(utf8_binary#x, collate(utf8_lcase#x, utf8_binary)) AS endswith(utf8_binary, collate(utf8_lcase, utf8_binary))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select endswith(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5
+-- !query analysis
+Project [EndsWith(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_lcase)) AS endswith(collate(utf8_binary, utf8_lcase), collate(utf8_lcase, utf8_lcase))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select endswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"collate(utf8_binary, unicode_ai)\"",
+ "inputType" : "\"STRING COLLATE UNICODE_AI\"",
+ "paramIndex" : "first",
+ "requiredType" : "\"STRING\"",
+ "sqlExpr" : "\"endswith(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai))\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 78,
+ "fragment" : "endswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)"
+ } ]
+}
+
+
+-- !query
+select endswith(utf8_binary, 'aaAaaAaA'), endswith(utf8_lcase, 'aaAaaAaA') from t5
+-- !query analysis
+Project [EndsWith(utf8_binary#x, aaAaaAaA) AS endswith(utf8_binary, aaAaaAaA)#x, EndsWith(utf8_lcase#x, cast(aaAaaAaA as string collate UTF8_LCASE)) AS endswith(utf8_lcase, aaAaaAaA)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select endswith(utf8_binary, 'aaAaaAaA' collate utf8_lcase), endswith(utf8_lcase, 'aaAaaAaA' collate utf8_binary) from t5
+-- !query analysis
+Project [EndsWith(cast(utf8_binary#x as string collate UTF8_LCASE), collate(aaAaaAaA, utf8_lcase)) AS endswith(utf8_binary, collate(aaAaaAaA, utf8_lcase))#x, EndsWith(cast(utf8_lcase#x as string), collate(aaAaaAaA, utf8_binary)) AS endswith(utf8_lcase, collate(aaAaaAaA, utf8_binary))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select repeat(utf8_binary, 3), repeat(utf8_lcase, 2) from t5
+-- !query analysis
+Project [repeat(utf8_binary#x, 3) AS repeat(utf8_binary, 3)#x, repeat(utf8_lcase#x, 2) AS repeat(utf8_lcase, 2)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select repeat(utf8_binary collate utf8_lcase, 3), repeat(utf8_lcase collate utf8_binary, 2) from t5
+-- !query analysis
+Project [repeat(collate(utf8_binary#x, utf8_lcase), 3) AS repeat(collate(utf8_binary, utf8_lcase), 3)#x, repeat(collate(utf8_lcase#x, utf8_binary), 2) AS repeat(collate(utf8_lcase, utf8_binary), 2)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select ascii(utf8_binary), ascii(utf8_lcase) from t5
+-- !query analysis
+Project [ascii(utf8_binary#x) AS ascii(utf8_binary)#x, ascii(utf8_lcase#x) AS ascii(utf8_lcase)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select ascii(utf8_binary collate utf8_lcase), ascii(utf8_lcase collate utf8_binary) from t5
+-- !query analysis
+Project [ascii(collate(utf8_binary#x, utf8_lcase)) AS ascii(collate(utf8_binary, utf8_lcase))#x, ascii(collate(utf8_lcase#x, utf8_binary)) AS ascii(collate(utf8_lcase, utf8_binary))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select unbase64(utf8_binary), unbase64(utf8_lcase) from t10
+-- !query analysis
+Project [unbase64(utf8_binary#x, false) AS unbase64(utf8_binary)#x, unbase64(utf8_lcase#x, false) AS unbase64(utf8_lcase)#x]
++- SubqueryAlias spark_catalog.default.t10
+ +- Relation spark_catalog.default.t10[utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select unbase64(utf8_binary collate utf8_lcase), unbase64(utf8_lcase collate utf8_binary) from t10
+-- !query analysis
+Project [unbase64(collate(utf8_binary#x, utf8_lcase), false) AS unbase64(collate(utf8_binary, utf8_lcase))#x, unbase64(collate(utf8_lcase#x, utf8_binary), false) AS unbase64(collate(utf8_lcase, utf8_binary))#x]
++- SubqueryAlias spark_catalog.default.t10
+ +- Relation spark_catalog.default.t10[utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select chr(ascii) from t6
+-- !query analysis
+Project [chr(ascii#xL) AS chr(ascii)#x]
++- SubqueryAlias spark_catalog.default.t6
+ +- Relation spark_catalog.default.t6[ascii#xL] parquet
+
+
+-- !query
+select base64(utf8_binary), base64(utf8_lcase) from t5
+-- !query analysis
+Project [base64(cast(utf8_binary#x as binary)) AS base64(utf8_binary)#x, base64(cast(utf8_lcase#x as binary)) AS base64(utf8_lcase)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select base64(utf8_binary collate utf8_lcase), base64(utf8_lcase collate utf8_binary) from t5
+-- !query analysis
+Project [base64(cast(collate(utf8_binary#x, utf8_lcase) as binary)) AS base64(collate(utf8_binary, utf8_lcase))#x, base64(cast(collate(utf8_lcase#x, utf8_binary) as binary)) AS base64(collate(utf8_lcase, utf8_binary))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select decode(encode(utf8_binary, 'utf-8'), 'utf-8'), decode(encode(utf8_lcase, 'utf-8'), 'utf-8') from t5
+-- !query analysis
+Project [decode(encode(utf8_binary#x, utf-8), utf-8) AS decode(encode(utf8_binary, utf-8), utf-8)#x, decode(encode(utf8_lcase#x, utf-8), utf-8) AS decode(encode(utf8_lcase, utf-8), utf-8)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select decode(encode(utf8_binary collate utf8_lcase, 'utf-8'), 'utf-8'), decode(encode(utf8_lcase collate utf8_binary, 'utf-8'), 'utf-8') from t5
+-- !query analysis
+Project [decode(encode(collate(utf8_binary#x, utf8_lcase), utf-8), utf-8) AS decode(encode(collate(utf8_binary, utf8_lcase), utf-8), utf-8)#x, decode(encode(collate(utf8_lcase#x, utf8_binary), utf-8), utf-8) AS decode(encode(collate(utf8_lcase, utf8_binary), utf-8), utf-8)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select format_number(ascii, '###.###') from t7
+-- !query analysis
+Project [format_number(ascii#x, ###.###) AS format_number(ascii, ###.###)#x]
++- SubqueryAlias spark_catalog.default.t7
+ +- Relation spark_catalog.default.t7[ascii#x] parquet
+
+
+-- !query
+select format_number(ascii, '###.###' collate utf8_lcase) from t7
+-- !query analysis
+Project [format_number(ascii#x, collate(###.###, utf8_lcase)) AS format_number(ascii, collate(###.###, utf8_lcase))#x]
++- SubqueryAlias spark_catalog.default.t7
+ +- Relation spark_catalog.default.t7[ascii#x] parquet
+
+
+-- !query
+select encode(utf8_binary, 'utf-8'), encode(utf8_lcase, 'utf-8') from t5
+-- !query analysis
+Project [encode(utf8_binary#x, utf-8) AS encode(utf8_binary, utf-8)#x, encode(utf8_lcase#x, utf-8) AS encode(utf8_lcase, utf-8)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select encode(utf8_binary collate utf8_lcase, 'utf-8'), encode(utf8_lcase collate utf8_binary, 'utf-8') from t5
+-- !query analysis
+Project [encode(collate(utf8_binary#x, utf8_lcase), utf-8) AS encode(collate(utf8_binary, utf8_lcase), utf-8)#x, encode(collate(utf8_lcase#x, utf8_binary), utf-8) AS encode(collate(utf8_lcase, utf8_binary), utf-8)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select to_binary(utf8_binary, 'utf-8'), to_binary(utf8_lcase, 'utf-8') from t5
+-- !query analysis
+Project [to_binary(utf8_binary#x, Some(utf-8), false) AS to_binary(utf8_binary, utf-8)#x, to_binary(utf8_lcase#x, Some(utf-8), false) AS to_binary(utf8_lcase, utf-8)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select to_binary(utf8_binary collate utf8_lcase, 'utf-8'), to_binary(utf8_lcase collate utf8_binary, 'utf-8') from t5
+-- !query analysis
+Project [to_binary(collate(utf8_binary#x, utf8_lcase), Some(utf-8), false) AS to_binary(collate(utf8_binary, utf8_lcase), utf-8)#x, to_binary(collate(utf8_lcase#x, utf8_binary), Some(utf-8), false) AS to_binary(collate(utf8_lcase, utf8_binary), utf-8)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select sentences(utf8_binary), sentences(utf8_lcase) from t5
+-- !query analysis
+Project [sentences(utf8_binary#x, , ) AS sentences(utf8_binary, , )#x, sentences(utf8_lcase#x, , ) AS sentences(utf8_lcase, , )#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select sentences(utf8_binary collate utf8_lcase), sentences(utf8_lcase collate utf8_binary) from t5
+-- !query analysis
+Project [sentences(collate(utf8_binary#x, utf8_lcase), , ) AS sentences(collate(utf8_binary, utf8_lcase), , )#x, sentences(collate(utf8_lcase#x, utf8_binary), , ) AS sentences(collate(utf8_lcase, utf8_binary), , )#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select upper(utf8_binary), upper(utf8_lcase) from t5
+-- !query analysis
+Project [upper(utf8_binary#x) AS upper(utf8_binary)#x, upper(utf8_lcase#x) AS upper(utf8_lcase)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select upper(utf8_binary collate utf8_lcase), upper(utf8_lcase collate utf8_binary) from t5
+-- !query analysis
+Project [upper(collate(utf8_binary#x, utf8_lcase)) AS upper(collate(utf8_binary, utf8_lcase))#x, upper(collate(utf8_lcase#x, utf8_binary)) AS upper(collate(utf8_lcase, utf8_binary))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select lower(utf8_binary), lower(utf8_lcase) from t5
+-- !query analysis
+Project [lower(utf8_binary#x) AS lower(utf8_binary)#x, lower(utf8_lcase#x) AS lower(utf8_lcase)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select lower(utf8_binary collate utf8_lcase), lower(utf8_lcase collate utf8_binary) from t5
+-- !query analysis
+Project [lower(collate(utf8_binary#x, utf8_lcase)) AS lower(collate(utf8_binary, utf8_lcase))#x, lower(collate(utf8_lcase#x, utf8_binary)) AS lower(collate(utf8_lcase, utf8_binary))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select initcap(utf8_binary), initcap(utf8_lcase) from t5
+-- !query analysis
+Project [initcap(utf8_binary#x) AS initcap(utf8_binary)#x, initcap(utf8_lcase#x) AS initcap(utf8_lcase)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select initcap(utf8_binary collate utf8_lcase), initcap(utf8_lcase collate utf8_binary) from t5
+-- !query analysis
+Project [initcap(collate(utf8_binary#x, utf8_lcase)) AS initcap(collate(utf8_binary, utf8_lcase))#x, initcap(collate(utf8_lcase#x, utf8_binary)) AS initcap(collate(utf8_lcase, utf8_binary))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select overlay(utf8_binary, utf8_lcase, 2) from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.IMPLICIT",
+ "sqlState" : "42P21"
+}
+
+
+-- !query
+select overlay(s, utf8_binary,1) from t5
+-- !query analysis
+Project [overlay(s#x, utf8_binary#x, 1, -1) AS overlay(s, utf8_binary, 1, -1)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select overlay(utf8_binary collate utf8_binary, s collate utf8_lcase, 3) from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.EXPLICIT",
+ "sqlState" : "42P21",
+ "messageParameters" : {
+ "explicitTypes" : "`string`, `string collate UTF8_LCASE`"
+ }
+}
+
+
+-- !query
+select overlay(utf8_binary, utf8_lcase collate utf8_binary, 2) from t5
+-- !query analysis
+Project [overlay(utf8_binary#x, collate(utf8_lcase#x, utf8_binary), 2, -1) AS overlay(utf8_binary, collate(utf8_lcase, utf8_binary), 2, -1)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select overlay(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase, 2) from t5
+-- !query analysis
+Project [overlay(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_lcase), 2, -1) AS overlay(collate(utf8_binary, utf8_lcase), collate(utf8_lcase, utf8_lcase), 2, -1)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select overlay(utf8_binary, 'a', 2), overlay(utf8_lcase, 'a', 2) from t5
+-- !query analysis
+Project [overlay(utf8_binary#x, a, 2, -1) AS overlay(utf8_binary, a, 2, -1)#x, overlay(utf8_lcase#x, cast(a as string collate UTF8_LCASE), 2, -1) AS overlay(utf8_lcase, a, 2, -1)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select overlay(utf8_binary, 'AaAA' collate utf8_lcase, 2), overlay(utf8_lcase, 'AAa' collate utf8_binary, 2) from t5
+-- !query analysis
+Project [overlay(cast(utf8_binary#x as string collate UTF8_LCASE), collate(AaAA, utf8_lcase), 2, -1) AS overlay(utf8_binary, collate(AaAA, utf8_lcase), 2, -1)#x, overlay(cast(utf8_lcase#x as string), collate(AAa, utf8_binary), 2, -1) AS overlay(utf8_lcase, collate(AAa, utf8_binary), 2, -1)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select format_string(format, utf8_binary, utf8_lcase) from t8
+-- !query analysis
+Project [format_string(format#x, utf8_binary#x, utf8_lcase#x) AS format_string(format, utf8_binary, utf8_lcase)#x]
++- SubqueryAlias spark_catalog.default.t8
+ +- Relation spark_catalog.default.t8[format#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select format_string(format collate utf8_lcase, utf8_lcase, utf8_binary collate utf8_lcase, 3), format_string(format, utf8_lcase collate utf8_binary, utf8_binary) from t8
+-- !query analysis
+Project [format_string(collate(format#x, utf8_lcase), utf8_lcase#x, collate(utf8_binary#x, utf8_lcase), 3) AS format_string(collate(format, utf8_lcase), utf8_lcase, collate(utf8_binary, utf8_lcase), 3)#x, format_string(format#x, collate(utf8_lcase#x, utf8_binary), utf8_binary#x) AS format_string(format, collate(utf8_lcase, utf8_binary), utf8_binary)#x]
++- SubqueryAlias spark_catalog.default.t8
+ +- Relation spark_catalog.default.t8[format#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select format_string(format, utf8_binary, utf8_lcase) from t8
+-- !query analysis
+Project [format_string(format#x, utf8_binary#x, utf8_lcase#x) AS format_string(format, utf8_binary, utf8_lcase)#x]
++- SubqueryAlias spark_catalog.default.t8
+ +- Relation spark_catalog.default.t8[format#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select soundex(utf8_binary), soundex(utf8_lcase) from t5
+-- !query analysis
+Project [soundex(utf8_binary#x) AS soundex(utf8_binary)#x, soundex(utf8_lcase#x) AS soundex(utf8_lcase)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select soundex(utf8_binary collate utf8_lcase), soundex(utf8_lcase collate utf8_binary) from t5
+-- !query analysis
+Project [soundex(collate(utf8_binary#x, utf8_lcase)) AS soundex(collate(utf8_binary, utf8_lcase))#x, soundex(collate(utf8_lcase#x, utf8_binary)) AS soundex(collate(utf8_lcase, utf8_binary))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select length(utf8_binary), length(utf8_lcase) from t5
+-- !query analysis
+Project [length(utf8_binary#x) AS length(utf8_binary)#x, length(utf8_lcase#x) AS length(utf8_lcase)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select length(utf8_binary collate utf8_lcase), length(utf8_lcase collate utf8_binary) from t5
+-- !query analysis
+Project [length(collate(utf8_binary#x, utf8_lcase)) AS length(collate(utf8_binary, utf8_lcase))#x, length(collate(utf8_lcase#x, utf8_binary)) AS length(collate(utf8_lcase, utf8_binary))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select bit_length(utf8_binary), bit_length(utf8_lcase) from t5
+-- !query analysis
+Project [bit_length(utf8_binary#x) AS bit_length(utf8_binary)#x, bit_length(utf8_lcase#x) AS bit_length(utf8_lcase)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select bit_length(utf8_binary collate utf8_lcase), bit_length(utf8_lcase collate utf8_binary) from t5
+-- !query analysis
+Project [bit_length(collate(utf8_binary#x, utf8_lcase)) AS bit_length(collate(utf8_binary, utf8_lcase))#x, bit_length(collate(utf8_lcase#x, utf8_binary)) AS bit_length(collate(utf8_lcase, utf8_binary))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select octet_length(utf8_binary), octet_length(utf8_lcase) from t5
+-- !query analysis
+Project [octet_length(utf8_binary#x) AS octet_length(utf8_binary)#x, octet_length(utf8_lcase#x) AS octet_length(utf8_lcase)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select octet_length(utf8_binary collate utf8_lcase), octet_length(utf8_lcase collate utf8_binary) from t5
+-- !query analysis
+Project [octet_length(collate(utf8_binary#x, utf8_lcase)) AS octet_length(collate(utf8_binary, utf8_lcase))#x, octet_length(collate(utf8_lcase#x, utf8_binary)) AS octet_length(collate(utf8_lcase, utf8_binary))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select luhn_check(num) from t9
+-- !query analysis
+Project [luhn_check(cast(num#xL as string)) AS luhn_check(num)#x]
++- SubqueryAlias spark_catalog.default.t9
+ +- Relation spark_catalog.default.t9[num#xL] parquet
+
+
+-- !query
+select levenshtein(utf8_binary, utf8_lcase) from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.IMPLICIT",
+ "sqlState" : "42P21"
+}
+
+
+-- !query
+select levenshtein(s, utf8_binary) from t5
+-- !query analysis
+Project [levenshtein(s#x, utf8_binary#x, None) AS levenshtein(s, utf8_binary)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select levenshtein(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.EXPLICIT",
+ "sqlState" : "42P21",
+ "messageParameters" : {
+ "explicitTypes" : "`string`, `string collate UTF8_LCASE`"
+ }
+}
+
+
+-- !query
+select levenshtein(utf8_binary, utf8_lcase collate utf8_binary) from t5
+-- !query analysis
+Project [levenshtein(utf8_binary#x, collate(utf8_lcase#x, utf8_binary), None) AS levenshtein(utf8_binary, collate(utf8_lcase, utf8_binary))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select levenshtein(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5
+-- !query analysis
+Project [levenshtein(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_lcase), None) AS levenshtein(collate(utf8_binary, utf8_lcase), collate(utf8_lcase, utf8_lcase))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select levenshtein(utf8_binary, 'a'), levenshtein(utf8_lcase, 'a') from t5
+-- !query analysis
+Project [levenshtein(utf8_binary#x, a, None) AS levenshtein(utf8_binary, a)#x, levenshtein(utf8_lcase#x, cast(a as string collate UTF8_LCASE), None) AS levenshtein(utf8_lcase, a)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select levenshtein(utf8_binary, 'AaAA' collate utf8_lcase, 3), levenshtein(utf8_lcase, 'AAa' collate utf8_binary, 4) from t5
+-- !query analysis
+Project [levenshtein(cast(utf8_binary#x as string collate UTF8_LCASE), collate(AaAA, utf8_lcase), Some(3)) AS levenshtein(utf8_binary, collate(AaAA, utf8_lcase), 3)#x, levenshtein(cast(utf8_lcase#x as string), collate(AAa, utf8_binary), Some(4)) AS levenshtein(utf8_lcase, collate(AAa, utf8_binary), 4)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select is_valid_utf8(utf8_binary), is_valid_utf8(utf8_lcase) from t5
+-- !query analysis
+Project [is_valid_utf8(utf8_binary#x) AS is_valid_utf8(utf8_binary)#x, is_valid_utf8(utf8_lcase#x) AS is_valid_utf8(utf8_lcase)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select is_valid_utf8(utf8_binary collate utf8_lcase), is_valid_utf8(utf8_lcase collate utf8_binary) from t5
+-- !query analysis
+Project [is_valid_utf8(collate(utf8_binary#x, utf8_lcase)) AS is_valid_utf8(collate(utf8_binary, utf8_lcase))#x, is_valid_utf8(collate(utf8_lcase#x, utf8_binary)) AS is_valid_utf8(collate(utf8_lcase, utf8_binary))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select make_valid_utf8(utf8_binary), make_valid_utf8(utf8_lcase) from t5
+-- !query analysis
+Project [make_valid_utf8(utf8_binary#x) AS make_valid_utf8(utf8_binary)#x, make_valid_utf8(utf8_lcase#x) AS make_valid_utf8(utf8_lcase)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select make_valid_utf8(utf8_binary collate utf8_lcase), make_valid_utf8(utf8_lcase collate utf8_binary) from t5
+-- !query analysis
+Project [make_valid_utf8(collate(utf8_binary#x, utf8_lcase)) AS make_valid_utf8(collate(utf8_binary, utf8_lcase))#x, make_valid_utf8(collate(utf8_lcase#x, utf8_binary)) AS make_valid_utf8(collate(utf8_lcase, utf8_binary))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select validate_utf8(utf8_binary), validate_utf8(utf8_lcase) from t5
+-- !query analysis
+Project [validate_utf8(utf8_binary#x) AS validate_utf8(utf8_binary)#x, validate_utf8(utf8_lcase#x) AS validate_utf8(utf8_lcase)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select validate_utf8(utf8_binary collate utf8_lcase), validate_utf8(utf8_lcase collate utf8_binary) from t5
+-- !query analysis
+Project [validate_utf8(collate(utf8_binary#x, utf8_lcase)) AS validate_utf8(collate(utf8_binary, utf8_lcase))#x, validate_utf8(collate(utf8_lcase#x, utf8_binary)) AS validate_utf8(collate(utf8_lcase, utf8_binary))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select try_validate_utf8(utf8_binary), try_validate_utf8(utf8_lcase) from t5
+-- !query analysis
+Project [try_validate_utf8(utf8_binary#x) AS try_validate_utf8(utf8_binary)#x, try_validate_utf8(utf8_lcase#x) AS try_validate_utf8(utf8_lcase)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select try_validate_utf8(utf8_binary collate utf8_lcase), try_validate_utf8(utf8_lcase collate utf8_binary) from t5
+-- !query analysis
+Project [try_validate_utf8(collate(utf8_binary#x, utf8_lcase)) AS try_validate_utf8(collate(utf8_binary, utf8_lcase))#x, try_validate_utf8(collate(utf8_lcase#x, utf8_binary)) AS try_validate_utf8(collate(utf8_lcase, utf8_binary))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select substr(utf8_binary, 2, 2), substr(utf8_lcase, 2, 2) from t5
+-- !query analysis
+Project [substr(utf8_binary#x, 2, 2) AS substr(utf8_binary, 2, 2)#x, substr(utf8_lcase#x, 2, 2) AS substr(utf8_lcase, 2, 2)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select substr(utf8_binary collate utf8_lcase, 2, 2), substr(utf8_lcase collate utf8_binary, 2, 2) from t5
+-- !query analysis
+Project [substr(collate(utf8_binary#x, utf8_lcase), 2, 2) AS substr(collate(utf8_binary, utf8_lcase), 2, 2)#x, substr(collate(utf8_lcase#x, utf8_binary), 2, 2) AS substr(collate(utf8_lcase, utf8_binary), 2, 2)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select right(utf8_binary, 2), right(utf8_lcase, 2) from t5
+-- !query analysis
+Project [right(utf8_binary#x, 2) AS right(utf8_binary, 2)#x, right(utf8_lcase#x, 2) AS right(utf8_lcase, 2)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select right(utf8_binary collate utf8_lcase, 2), right(utf8_lcase collate utf8_binary, 2) from t5
+-- !query analysis
+Project [right(collate(utf8_binary#x, utf8_lcase), 2) AS right(collate(utf8_binary, utf8_lcase), 2)#x, right(collate(utf8_lcase#x, utf8_binary), 2) AS right(collate(utf8_lcase, utf8_binary), 2)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select left(utf8_binary, '2' collate utf8_lcase), left(utf8_lcase, 2) from t5
+-- !query analysis
+Project [left(utf8_binary#x, cast(collate(2, utf8_lcase) as int)) AS left(utf8_binary, collate(2, utf8_lcase))#x, left(utf8_lcase#x, 2) AS left(utf8_lcase, 2)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select left(utf8_binary collate utf8_lcase, 2), left(utf8_lcase collate utf8_binary, 2) from t5
+-- !query analysis
+Project [left(collate(utf8_binary#x, utf8_lcase), 2) AS left(collate(utf8_binary, utf8_lcase), 2)#x, left(collate(utf8_lcase#x, utf8_binary), 2) AS left(collate(utf8_lcase, utf8_binary), 2)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select rpad(utf8_binary, 8, utf8_lcase) from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.IMPLICIT",
+ "sqlState" : "42P21"
+}
+
+
+-- !query
+select rpad(s, 8, utf8_binary) from t5
+-- !query analysis
+Project [rpad(s#x, 8, utf8_binary#x) AS rpad(s, 8, utf8_binary)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select rpad(utf8_binary collate utf8_binary, 8, s collate utf8_lcase) from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.EXPLICIT",
+ "sqlState" : "42P21",
+ "messageParameters" : {
+ "explicitTypes" : "`string`, `string collate UTF8_LCASE`"
+ }
+}
+
+
+-- !query
+select rpad(utf8_binary, 8, utf8_lcase collate utf8_binary) from t5
+-- !query analysis
+Project [rpad(utf8_binary#x, 8, collate(utf8_lcase#x, utf8_binary)) AS rpad(utf8_binary, 8, collate(utf8_lcase, utf8_binary))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select rpad(utf8_binary collate utf8_lcase, 8, utf8_lcase collate utf8_lcase) from t5
+-- !query analysis
+Project [rpad(collate(utf8_binary#x, utf8_lcase), 8, collate(utf8_lcase#x, utf8_lcase)) AS rpad(collate(utf8_binary, utf8_lcase), 8, collate(utf8_lcase, utf8_lcase))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select rpad(utf8_binary, 8, 'a'), rpad(utf8_lcase, 8, 'a') from t5
+-- !query analysis
+Project [rpad(utf8_binary#x, 8, a) AS rpad(utf8_binary, 8, a)#x, rpad(utf8_lcase#x, 8, cast(a as string collate UTF8_LCASE)) AS rpad(utf8_lcase, 8, a)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select rpad(utf8_binary, 8, 'AaAA' collate utf8_lcase), rpad(utf8_lcase, 8, 'AAa' collate utf8_binary) from t5
+-- !query analysis
+Project [rpad(cast(utf8_binary#x as string collate UTF8_LCASE), 8, collate(AaAA, utf8_lcase)) AS rpad(utf8_binary, 8, collate(AaAA, utf8_lcase))#x, rpad(cast(utf8_lcase#x as string), 8, collate(AAa, utf8_binary)) AS rpad(utf8_lcase, 8, collate(AAa, utf8_binary))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select lpad(utf8_binary, 8, utf8_lcase) from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.IMPLICIT",
+ "sqlState" : "42P21"
+}
+
+
+-- !query
+select lpad(s, 8, utf8_binary) from t5
+-- !query analysis
+Project [lpad(s#x, 8, utf8_binary#x) AS lpad(s, 8, utf8_binary)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select lpad(utf8_binary collate utf8_binary, 8, s collate utf8_lcase) from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.EXPLICIT",
+ "sqlState" : "42P21",
+ "messageParameters" : {
+ "explicitTypes" : "`string`, `string collate UTF8_LCASE`"
+ }
+}
+
+
+-- !query
+select lpad(utf8_binary, 8, utf8_lcase collate utf8_binary) from t5
+-- !query analysis
+Project [lpad(utf8_binary#x, 8, collate(utf8_lcase#x, utf8_binary)) AS lpad(utf8_binary, 8, collate(utf8_lcase, utf8_binary))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select lpad(utf8_binary collate utf8_lcase, 8, utf8_lcase collate utf8_lcase) from t5
+-- !query analysis
+Project [lpad(collate(utf8_binary#x, utf8_lcase), 8, collate(utf8_lcase#x, utf8_lcase)) AS lpad(collate(utf8_binary, utf8_lcase), 8, collate(utf8_lcase, utf8_lcase))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select lpad(utf8_binary, 8, 'a'), lpad(utf8_lcase, 8, 'a') from t5
+-- !query analysis
+Project [lpad(utf8_binary#x, 8, a) AS lpad(utf8_binary, 8, a)#x, lpad(utf8_lcase#x, 8, cast(a as string collate UTF8_LCASE)) AS lpad(utf8_lcase, 8, a)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select lpad(utf8_binary, 8, 'AaAA' collate utf8_lcase), lpad(utf8_lcase, 8, 'AAa' collate utf8_binary) from t5
+-- !query analysis
+Project [lpad(cast(utf8_binary#x as string collate UTF8_LCASE), 8, collate(AaAA, utf8_lcase)) AS lpad(utf8_binary, 8, collate(AaAA, utf8_lcase))#x, lpad(cast(utf8_lcase#x as string), 8, collate(AAa, utf8_binary)) AS lpad(utf8_lcase, 8, collate(AAa, utf8_binary))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select locate(utf8_binary, utf8_lcase) from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.IMPLICIT",
+ "sqlState" : "42P21"
+}
+
+
+-- !query
+select locate(s, utf8_binary) from t5
+-- !query analysis
+Project [locate(s#x, utf8_binary#x, 1) AS locate(s, utf8_binary, 1)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select locate(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.EXPLICIT",
+ "sqlState" : "42P21",
+ "messageParameters" : {
+ "explicitTypes" : "`string`, `string collate UTF8_LCASE`"
+ }
+}
+
+
+-- !query
+select locate(utf8_binary, utf8_lcase collate utf8_binary) from t5
+-- !query analysis
+Project [locate(utf8_binary#x, collate(utf8_lcase#x, utf8_binary), 1) AS locate(utf8_binary, collate(utf8_lcase, utf8_binary), 1)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select locate(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase, 3) from t5
+-- !query analysis
+Project [locate(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_lcase), 3) AS locate(collate(utf8_binary, utf8_lcase), collate(utf8_lcase, utf8_lcase), 3)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select locate(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 3) from t5
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"collate(utf8_binary, unicode_ai)\"",
+ "inputType" : "\"STRING COLLATE UNICODE_AI\"",
+ "paramIndex" : "first",
+ "requiredType" : "\"STRING\"",
+ "sqlExpr" : "\"locate(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), 3)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 79,
+ "fragment" : "locate(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 3)"
+ } ]
+}
+
+
+-- !query
+select locate(utf8_binary, 'a'), locate(utf8_lcase, 'a') from t5
+-- !query analysis
+Project [locate(utf8_binary#x, a, 1) AS locate(utf8_binary, a, 1)#x, locate(utf8_lcase#x, cast(a as string collate UTF8_LCASE), 1) AS locate(utf8_lcase, a, 1)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select locate(utf8_binary, 'AaAA' collate utf8_lcase, 4), locate(utf8_lcase, 'AAa' collate utf8_binary, 4) from t5
+-- !query analysis
+Project [locate(cast(utf8_binary#x as string collate UTF8_LCASE), collate(AaAA, utf8_lcase), 4) AS locate(utf8_binary, collate(AaAA, utf8_lcase), 4)#x, locate(cast(utf8_lcase#x as string), collate(AAa, utf8_binary), 4) AS locate(utf8_lcase, collate(AAa, utf8_binary), 4)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select TRIM(utf8_binary, utf8_lcase) from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.IMPLICIT",
+ "sqlState" : "42P21"
+}
+
+
+-- !query
+select TRIM(s, utf8_binary) from t5
+-- !query analysis
+Project [trim(utf8_binary#x, Some(s#x)) AS TRIM(BOTH s FROM utf8_binary)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select TRIM(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.EXPLICIT",
+ "sqlState" : "42P21",
+ "messageParameters" : {
+ "explicitTypes" : "`string collate UTF8_LCASE`, `string`"
+ }
+}
+
+
+-- !query
+select TRIM(utf8_binary, utf8_lcase collate utf8_binary) from t5
+-- !query analysis
+Project [trim(collate(utf8_lcase#x, utf8_binary), Some(utf8_binary#x)) AS TRIM(BOTH utf8_binary FROM collate(utf8_lcase, utf8_binary))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select TRIM(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5
+-- !query analysis
+Project [trim(collate(utf8_lcase#x, utf8_lcase), Some(collate(utf8_binary#x, utf8_lcase))) AS TRIM(BOTH collate(utf8_binary, utf8_lcase) FROM collate(utf8_lcase, utf8_lcase))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select TRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"collate(utf8_lcase, unicode_ai)\"",
+ "inputType" : "\"STRING COLLATE UNICODE_AI\"",
+ "paramIndex" : "first",
+ "requiredType" : "\"STRING\"",
+ "sqlExpr" : "\"TRIM(BOTH collate(utf8_binary, unicode_ai) FROM collate(utf8_lcase, unicode_ai))\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 74,
+ "fragment" : "TRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)"
+ } ]
+}
+
+
+-- !query
+select TRIM('ABc', utf8_binary), TRIM('ABc', utf8_lcase) from t5
+-- !query analysis
+Project [trim(utf8_binary#x, Some(ABc)) AS TRIM(BOTH ABc FROM utf8_binary)#x, trim(utf8_lcase#x, Some(cast(ABc as string collate UTF8_LCASE))) AS TRIM(BOTH ABc FROM utf8_lcase)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select TRIM('ABc' collate utf8_lcase, utf8_binary), TRIM('AAa' collate utf8_binary, utf8_lcase) from t5
+-- !query analysis
+Project [trim(cast(utf8_binary#x as string collate UTF8_LCASE), Some(collate(ABc, utf8_lcase))) AS TRIM(BOTH collate(ABc, utf8_lcase) FROM utf8_binary)#x, trim(cast(utf8_lcase#x as string), Some(collate(AAa, utf8_binary))) AS TRIM(BOTH collate(AAa, utf8_binary) FROM utf8_lcase)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select BTRIM(utf8_binary, utf8_lcase) from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.IMPLICIT",
+ "sqlState" : "42P21"
+}
+
+
+-- !query
+select BTRIM(s, utf8_binary) from t5
+-- !query analysis
+Project [btrim(s#x, utf8_binary#x) AS btrim(s, utf8_binary)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select BTRIM(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.EXPLICIT",
+ "sqlState" : "42P21",
+ "messageParameters" : {
+ "explicitTypes" : "`string`, `string collate UTF8_LCASE`"
+ }
+}
+
+
+-- !query
+select BTRIM(utf8_binary, utf8_lcase collate utf8_binary) from t5
+-- !query analysis
+Project [btrim(utf8_binary#x, collate(utf8_lcase#x, utf8_binary)) AS btrim(utf8_binary, collate(utf8_lcase, utf8_binary))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select BTRIM(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5
+-- !query analysis
+Project [btrim(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_lcase)) AS btrim(collate(utf8_binary, utf8_lcase), collate(utf8_lcase, utf8_lcase))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select BTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"collate(utf8_binary, unicode_ai)\"",
+ "inputType" : "\"STRING COLLATE UNICODE_AI\"",
+ "paramIndex" : "first",
+ "requiredType" : "\"STRING\"",
+ "sqlExpr" : "\"TRIM(BOTH collate(utf8_lcase, unicode_ai) FROM collate(utf8_binary, unicode_ai))\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 75,
+ "fragment" : "BTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)"
+ } ]
+}
+
+
+-- !query
+select BTRIM('ABc', utf8_binary), BTRIM('ABc', utf8_lcase) from t5
+-- !query analysis
+Project [btrim(ABc, utf8_binary#x) AS btrim(ABc, utf8_binary)#x, btrim(ABc, utf8_lcase#x) AS btrim(ABc, utf8_lcase)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select BTRIM('ABc' collate utf8_lcase, utf8_binary), BTRIM('AAa' collate utf8_binary, utf8_lcase) from t5
+-- !query analysis
+Project [btrim(collate(ABc, utf8_lcase), utf8_binary#x) AS btrim(collate(ABc, utf8_lcase), utf8_binary)#x, btrim(collate(AAa, utf8_binary), utf8_lcase#x) AS btrim(collate(AAa, utf8_binary), utf8_lcase)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select LTRIM(utf8_binary, utf8_lcase) from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.IMPLICIT",
+ "sqlState" : "42P21"
+}
+
+
+-- !query
+select LTRIM(s, utf8_binary) from t5
+-- !query analysis
+Project [ltrim(utf8_binary#x, Some(s#x)) AS TRIM(LEADING s FROM utf8_binary)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select LTRIM(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.EXPLICIT",
+ "sqlState" : "42P21",
+ "messageParameters" : {
+ "explicitTypes" : "`string collate UTF8_LCASE`, `string`"
+ }
+}
+
+
+-- !query
+select LTRIM(utf8_binary, utf8_lcase collate utf8_binary) from t5
+-- !query analysis
+Project [ltrim(collate(utf8_lcase#x, utf8_binary), Some(utf8_binary#x)) AS TRIM(LEADING utf8_binary FROM collate(utf8_lcase, utf8_binary))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select LTRIM(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5
+-- !query analysis
+Project [ltrim(collate(utf8_lcase#x, utf8_lcase), Some(collate(utf8_binary#x, utf8_lcase))) AS TRIM(LEADING collate(utf8_binary, utf8_lcase) FROM collate(utf8_lcase, utf8_lcase))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select LTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"collate(utf8_lcase, unicode_ai)\"",
+ "inputType" : "\"STRING COLLATE UNICODE_AI\"",
+ "paramIndex" : "first",
+ "requiredType" : "\"STRING\"",
+ "sqlExpr" : "\"TRIM(LEADING collate(utf8_binary, unicode_ai) FROM collate(utf8_lcase, unicode_ai))\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 75,
+ "fragment" : "LTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)"
+ } ]
+}
+
+
+-- !query
+select LTRIM('ABc', utf8_binary), LTRIM('ABc', utf8_lcase) from t5
+-- !query analysis
+Project [ltrim(utf8_binary#x, Some(ABc)) AS TRIM(LEADING ABc FROM utf8_binary)#x, ltrim(utf8_lcase#x, Some(cast(ABc as string collate UTF8_LCASE))) AS TRIM(LEADING ABc FROM utf8_lcase)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select LTRIM('ABc' collate utf8_lcase, utf8_binary), LTRIM('AAa' collate utf8_binary, utf8_lcase) from t5
+-- !query analysis
+Project [ltrim(cast(utf8_binary#x as string collate UTF8_LCASE), Some(collate(ABc, utf8_lcase))) AS TRIM(LEADING collate(ABc, utf8_lcase) FROM utf8_binary)#x, ltrim(cast(utf8_lcase#x as string), Some(collate(AAa, utf8_binary))) AS TRIM(LEADING collate(AAa, utf8_binary) FROM utf8_lcase)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select RTRIM(utf8_binary, utf8_lcase) from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.IMPLICIT",
+ "sqlState" : "42P21"
+}
+
+
+-- !query
+select RTRIM(s, utf8_binary) from t5
+-- !query analysis
+Project [rtrim(utf8_binary#x, Some(s#x)) AS TRIM(TRAILING s FROM utf8_binary)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select RTRIM(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "COLLATION_MISMATCH.EXPLICIT",
+ "sqlState" : "42P21",
+ "messageParameters" : {
+ "explicitTypes" : "`string collate UTF8_LCASE`, `string`"
+ }
+}
+
+
+-- !query
+select RTRIM(utf8_binary, utf8_lcase collate utf8_binary) from t5
+-- !query analysis
+Project [rtrim(collate(utf8_lcase#x, utf8_binary), Some(utf8_binary#x)) AS TRIM(TRAILING utf8_binary FROM collate(utf8_lcase, utf8_binary))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select RTRIM(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5
+-- !query analysis
+Project [rtrim(collate(utf8_lcase#x, utf8_lcase), Some(collate(utf8_binary#x, utf8_lcase))) AS TRIM(TRAILING collate(utf8_binary, utf8_lcase) FROM collate(utf8_lcase, utf8_lcase))#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select RTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"collate(utf8_lcase, unicode_ai)\"",
+ "inputType" : "\"STRING COLLATE UNICODE_AI\"",
+ "paramIndex" : "first",
+ "requiredType" : "\"STRING\"",
+ "sqlExpr" : "\"TRIM(TRAILING collate(utf8_binary, unicode_ai) FROM collate(utf8_lcase, unicode_ai))\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 75,
+ "fragment" : "RTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)"
+ } ]
+}
+
+
+-- !query
+select RTRIM('ABc', utf8_binary), RTRIM('ABc', utf8_lcase) from t5
+-- !query analysis
+Project [rtrim(utf8_binary#x, Some(ABc)) AS TRIM(TRAILING ABc FROM utf8_binary)#x, rtrim(utf8_lcase#x, Some(cast(ABc as string collate UTF8_LCASE))) AS TRIM(TRAILING ABc FROM utf8_lcase)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+select RTRIM('ABc' collate utf8_lcase, utf8_binary), RTRIM('AAa' collate utf8_binary, utf8_lcase) from t5
+-- !query analysis
+Project [rtrim(cast(utf8_binary#x as string collate UTF8_LCASE), Some(collate(ABc, utf8_lcase))) AS TRIM(TRAILING collate(ABc, utf8_lcase) FROM utf8_binary)#x, rtrim(cast(utf8_lcase#x as string), Some(collate(AAa, utf8_binary))) AS TRIM(TRAILING collate(AAa, utf8_binary) FROM utf8_lcase)#x]
++- SubqueryAlias spark_catalog.default.t5
+ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet
+
+
+-- !query
+drop table t5
+-- !query analysis
+DropTable false, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t5
+
+
+-- !query
+drop table t6
+-- !query analysis
+DropTable false, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t6
+
+
+-- !query
+drop table t7
+-- !query analysis
+DropTable false, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t7
+
+
+-- !query
+drop table t8
+-- !query analysis
+DropTable false, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t8
+
+
+-- !query
+drop table t9
+-- !query analysis
+DropTable false, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t9
+
+
+-- !query
+drop table t10
+-- !query analysis
+DropTable false, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t10
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/date.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/date.sql.out
index 48137e06467e8..88c7d7b4e7d72 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/date.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/date.sql.out
@@ -811,7 +811,7 @@ Project [to_date(26/October/2015, Some(dd/MMMMM/yyyy), Some(America/Los_Angeles)
-- !query
select from_json('{"d":"26/October/2015"}', 'd Date', map('dateFormat', 'dd/MMMMM/yyyy'))
-- !query analysis
-Project [from_json(StructField(d,DateType,true), (dateFormat,dd/MMMMM/yyyy), {"d":"26/October/2015"}, Some(America/Los_Angeles)) AS from_json({"d":"26/October/2015"})#x]
+Project [from_json(StructField(d,DateType,true), (dateFormat,dd/MMMMM/yyyy), {"d":"26/October/2015"}, Some(America/Los_Angeles), false) AS from_json({"d":"26/October/2015"})#x]
+- OneRowRelation
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/datetime-legacy.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/datetime-legacy.sql.out
index 1e49f4df8267a..4221db822d024 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/datetime-legacy.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/datetime-legacy.sql.out
@@ -811,7 +811,7 @@ Project [to_date(26/October/2015, Some(dd/MMMMM/yyyy), Some(America/Los_Angeles)
-- !query
select from_json('{"d":"26/October/2015"}', 'd Date', map('dateFormat', 'dd/MMMMM/yyyy'))
-- !query analysis
-Project [from_json(StructField(d,DateType,true), (dateFormat,dd/MMMMM/yyyy), {"d":"26/October/2015"}, Some(America/Los_Angeles)) AS from_json({"d":"26/October/2015"})#x]
+Project [from_json(StructField(d,DateType,true), (dateFormat,dd/MMMMM/yyyy), {"d":"26/October/2015"}, Some(America/Los_Angeles), false) AS from_json({"d":"26/October/2015"})#x]
+- OneRowRelation
@@ -1833,7 +1833,7 @@ Project [unix_timestamp(22 05 2020 Friday, dd MM yyyy EEEEE, Some(America/Los_An
-- !query
select from_json('{"t":"26/October/2015"}', 't Timestamp', map('timestampFormat', 'dd/MMMMM/yyyy'))
-- !query analysis
-Project [from_json(StructField(t,TimestampType,true), (timestampFormat,dd/MMMMM/yyyy), {"t":"26/October/2015"}, Some(America/Los_Angeles)) AS from_json({"t":"26/October/2015"})#x]
+Project [from_json(StructField(t,TimestampType,true), (timestampFormat,dd/MMMMM/yyyy), {"t":"26/October/2015"}, Some(America/Los_Angeles), false) AS from_json({"t":"26/October/2015"})#x]
+- OneRowRelation
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/execute-immediate.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/execute-immediate.sql.out
index 78bf1ccb1678c..ce510527c8781 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/execute-immediate.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/execute-immediate.sql.out
@@ -471,7 +471,6 @@ org.apache.spark.SparkNumberFormatException
"errorClass" : "CAST_INVALID_INPUT",
"sqlState" : "22018",
"messageParameters" : {
- "ansiConfig" : "\"spark.sql.ansi.enabled\"",
"expression" : "'invalid_cast_error_expected'",
"sourceType" : "\"STRING\"",
"targetType" : "\"INT\""
@@ -662,7 +661,6 @@ org.apache.spark.SparkNumberFormatException
"errorClass" : "CAST_INVALID_INPUT",
"sqlState" : "22018",
"messageParameters" : {
- "ansiConfig" : "\"spark.sql.ansi.enabled\"",
"expression" : "'name1'",
"sourceType" : "\"STRING\"",
"targetType" : "\"INT\""
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/interval.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/interval.sql.out
index 290e55052931d..efa149509751d 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/interval.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/interval.sql.out
@@ -981,9 +981,13 @@ select interval '20 15:40:32.99899999' day to hour
-- !query analysis
org.apache.spark.sql.catalyst.parser.ParseException
{
- "errorClass" : "_LEGACY_ERROR_TEMP_0063",
+ "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE",
+ "sqlState" : "22006",
"messageParameters" : {
- "msg" : "Interval string does not match day-time format of `[+|-]d h`, `INTERVAL [+|-]'[+|-]d h' DAY TO HOUR` when cast to interval day to hour: 20 15:40:32.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0."
+ "input" : "20 15:40:32.99899999",
+ "intervalStr" : "day-time",
+ "supportedFormat" : "`[+|-]d h`, `INTERVAL [+|-]'[+|-]d h' DAY TO HOUR`",
+ "typeName" : "interval day to hour"
},
"queryContext" : [ {
"objectType" : "",
@@ -1000,9 +1004,13 @@ select interval '20 15:40:32.99899999' day to minute
-- !query analysis
org.apache.spark.sql.catalyst.parser.ParseException
{
- "errorClass" : "_LEGACY_ERROR_TEMP_0063",
+ "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE",
+ "sqlState" : "22006",
"messageParameters" : {
- "msg" : "Interval string does not match day-time format of `[+|-]d h:m`, `INTERVAL [+|-]'[+|-]d h:m' DAY TO MINUTE` when cast to interval day to minute: 20 15:40:32.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0."
+ "input" : "20 15:40:32.99899999",
+ "intervalStr" : "day-time",
+ "supportedFormat" : "`[+|-]d h:m`, `INTERVAL [+|-]'[+|-]d h:m' DAY TO MINUTE`",
+ "typeName" : "interval day to minute"
},
"queryContext" : [ {
"objectType" : "",
@@ -1019,9 +1027,13 @@ select interval '15:40:32.99899999' hour to minute
-- !query analysis
org.apache.spark.sql.catalyst.parser.ParseException
{
- "errorClass" : "_LEGACY_ERROR_TEMP_0063",
+ "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE",
+ "sqlState" : "22006",
"messageParameters" : {
- "msg" : "Interval string does not match day-time format of `[+|-]h:m`, `INTERVAL [+|-]'[+|-]h:m' HOUR TO MINUTE` when cast to interval hour to minute: 15:40:32.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0."
+ "input" : "15:40:32.99899999",
+ "intervalStr" : "day-time",
+ "supportedFormat" : "`[+|-]h:m`, `INTERVAL [+|-]'[+|-]h:m' HOUR TO MINUTE`",
+ "typeName" : "interval hour to minute"
},
"queryContext" : [ {
"objectType" : "",
@@ -1038,9 +1050,13 @@ select interval '15:40.99899999' hour to second
-- !query analysis
org.apache.spark.sql.catalyst.parser.ParseException
{
- "errorClass" : "_LEGACY_ERROR_TEMP_0063",
+ "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE",
+ "sqlState" : "22006",
"messageParameters" : {
- "msg" : "Interval string does not match day-time format of `[+|-]h:m:s.n`, `INTERVAL [+|-]'[+|-]h:m:s.n' HOUR TO SECOND` when cast to interval hour to second: 15:40.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0."
+ "input" : "15:40.99899999",
+ "intervalStr" : "day-time",
+ "supportedFormat" : "`[+|-]h:m:s.n`, `INTERVAL [+|-]'[+|-]h:m:s.n' HOUR TO SECOND`",
+ "typeName" : "interval hour to second"
},
"queryContext" : [ {
"objectType" : "",
@@ -1057,9 +1073,13 @@ select interval '15:40' hour to second
-- !query analysis
org.apache.spark.sql.catalyst.parser.ParseException
{
- "errorClass" : "_LEGACY_ERROR_TEMP_0063",
+ "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE",
+ "sqlState" : "22006",
"messageParameters" : {
- "msg" : "Interval string does not match day-time format of `[+|-]h:m:s.n`, `INTERVAL [+|-]'[+|-]h:m:s.n' HOUR TO SECOND` when cast to interval hour to second: 15:40, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0."
+ "input" : "15:40",
+ "intervalStr" : "day-time",
+ "supportedFormat" : "`[+|-]h:m:s.n`, `INTERVAL [+|-]'[+|-]h:m:s.n' HOUR TO SECOND`",
+ "typeName" : "interval hour to second"
},
"queryContext" : [ {
"objectType" : "",
@@ -1076,9 +1096,13 @@ select interval '20 40:32.99899999' minute to second
-- !query analysis
org.apache.spark.sql.catalyst.parser.ParseException
{
- "errorClass" : "_LEGACY_ERROR_TEMP_0063",
+ "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE",
+ "sqlState" : "22006",
"messageParameters" : {
- "msg" : "Interval string does not match day-time format of `[+|-]m:s.n`, `INTERVAL [+|-]'[+|-]m:s.n' MINUTE TO SECOND` when cast to interval minute to second: 20 40:32.99899999, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0."
+ "input" : "20 40:32.99899999",
+ "intervalStr" : "day-time",
+ "supportedFormat" : "`[+|-]m:s.n`, `INTERVAL [+|-]'[+|-]m:s.n' MINUTE TO SECOND`",
+ "typeName" : "interval minute to second"
},
"queryContext" : [ {
"objectType" : "",
@@ -1460,9 +1484,11 @@ SELECT INTERVAL '178956970-8' YEAR TO MONTH
-- !query analysis
org.apache.spark.sql.catalyst.parser.ParseException
{
- "errorClass" : "_LEGACY_ERROR_TEMP_0063",
+ "errorClass" : "INVALID_INTERVAL_FORMAT.INTERVAL_PARSING",
+ "sqlState" : "22006",
"messageParameters" : {
- "msg" : "Error parsing interval year-month string: integer overflow"
+ "input" : "178956970-8",
+ "interval" : "year-month"
},
"queryContext" : [ {
"objectType" : "",
@@ -1909,9 +1935,13 @@ select interval '-\t2-2\t' year to month
-- !query analysis
org.apache.spark.sql.catalyst.parser.ParseException
{
- "errorClass" : "_LEGACY_ERROR_TEMP_0063",
+ "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING",
+ "sqlState" : "22006",
"messageParameters" : {
- "msg" : "Interval string does not match year-month format of `[+|-]y-m`, `INTERVAL [+|-]'[+|-]y-m' YEAR TO MONTH` when cast to interval year to month: -\t2-2\t"
+ "input" : "-\t2-2\t",
+ "intervalStr" : "year-month",
+ "supportedFormat" : "`[+|-]y-m`, `INTERVAL [+|-]'[+|-]y-m' YEAR TO MONTH`",
+ "typeName" : "interval year to month"
},
"queryContext" : [ {
"objectType" : "",
@@ -1935,9 +1965,13 @@ select interval '\n-\t10\t 12:34:46.789\t' day to second
-- !query analysis
org.apache.spark.sql.catalyst.parser.ParseException
{
- "errorClass" : "_LEGACY_ERROR_TEMP_0063",
+ "errorClass" : "INVALID_INTERVAL_FORMAT.UNMATCHED_FORMAT_STRING_WITH_NOTICE",
+ "sqlState" : "22006",
"messageParameters" : {
- "msg" : "Interval string does not match day-time format of `[+|-]d h:m:s.n`, `INTERVAL [+|-]'[+|-]d h:m:s.n' DAY TO SECOND` when cast to interval day to second: \n-\t10\t 12:34:46.789\t, set spark.sql.legacy.fromDayTimeString.enabled to true to restore the behavior before Spark 3.0."
+ "input" : "\n-\t10\t 12:34:46.789\t",
+ "intervalStr" : "day-time",
+ "supportedFormat" : "`[+|-]d h:m:s.n`, `INTERVAL [+|-]'[+|-]d h:m:s.n' DAY TO SECOND`",
+ "typeName" : "interval day to second"
},
"queryContext" : [ {
"objectType" : "",
@@ -2074,7 +2108,7 @@ SELECT
to_csv(named_struct('a', interval 32 year, 'b', interval 10 month)),
from_csv(to_csv(named_struct('a', interval 32 year, 'b', interval 10 month)), 'a interval year, b interval month')
-- !query analysis
-Project [from_json(StructField(a,CalendarIntervalType,true), {"a":"1 days"}, Some(America/Los_Angeles)) AS from_json({"a":"1 days"})#x, from_csv(StructField(a,IntegerType,true), StructField(b,YearMonthIntervalType(0,0),true), 1, 1, Some(America/Los_Angeles), None) AS from_csv(1, 1)#x, to_json(from_json(StructField(a,CalendarIntervalType,true), {"a":"1 days"}, Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1 days"}))#x, to_csv(from_csv(StructField(a,IntegerType,true), StructField(b,YearMonthIntervalType(0,0),true), 1, 1, Some(America/Los_Angeles), None), Some(America/Los_Angeles)) AS to_csv(from_csv(1, 1))#x, to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH), Some(America/Los_Angeles)) AS to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH))#x, from_csv(StructField(a,YearMonthIntervalType(0,0),true), StructField(b,YearMonthIntervalType(1,1),true), to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH), Some(America/Los_Angeles)), Some(America/Los_Angeles), None) AS from_csv(to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH)))#x]
+Project [from_json(StructField(a,CalendarIntervalType,true), {"a":"1 days"}, Some(America/Los_Angeles), false) AS from_json({"a":"1 days"})#x, from_csv(StructField(a,IntegerType,true), StructField(b,YearMonthIntervalType(0,0),true), 1, 1, Some(America/Los_Angeles), None) AS from_csv(1, 1)#x, to_json(from_json(StructField(a,CalendarIntervalType,true), {"a":"1 days"}, Some(America/Los_Angeles), false), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1 days"}))#x, to_csv(from_csv(StructField(a,IntegerType,true), StructField(b,YearMonthIntervalType(0,0),true), 1, 1, Some(America/Los_Angeles), None), Some(America/Los_Angeles)) AS to_csv(from_csv(1, 1))#x, to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH), Some(America/Los_Angeles)) AS to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH))#x, from_csv(StructField(a,YearMonthIntervalType(0,0),true), StructField(b,YearMonthIntervalType(1,1),true), to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH), Some(America/Los_Angeles)), Some(America/Los_Angeles), None) AS from_csv(to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH)))#x]
+- OneRowRelation
@@ -2085,7 +2119,7 @@ SELECT
to_json(map('a', interval 100 day 130 minute)),
from_json(to_json(map('a', interval 100 day 130 minute)), 'a interval day to minute')
-- !query analysis
-Project [from_json(StructField(a,DayTimeIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles)) AS from_json({"a":"1"})#x, to_json(from_json(StructField(a,DayTimeIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1"}))#x, to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE), Some(America/Los_Angeles)) AS to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE))#x, from_json(StructField(a,DayTimeIntervalType(0,2),true), to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE), Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS from_json(to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE)))#x]
+Project [from_json(StructField(a,DayTimeIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles), false) AS from_json({"a":"1"})#x, to_json(from_json(StructField(a,DayTimeIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles), false), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1"}))#x, to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE), Some(America/Los_Angeles)) AS to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE))#x, from_json(StructField(a,DayTimeIntervalType(0,2),true), to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE), Some(America/Los_Angeles)), Some(America/Los_Angeles), false) AS from_json(to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE)))#x]
+- OneRowRelation
@@ -2096,7 +2130,7 @@ SELECT
to_json(map('a', interval 32 year 10 month)),
from_json(to_json(map('a', interval 32 year 10 month)), 'a interval year to month')
-- !query analysis
-Project [from_json(StructField(a,YearMonthIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles)) AS from_json({"a":"1"})#x, to_json(from_json(StructField(a,YearMonthIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1"}))#x, to_json(map(a, INTERVAL '32-10' YEAR TO MONTH), Some(America/Los_Angeles)) AS to_json(map(a, INTERVAL '32-10' YEAR TO MONTH))#x, from_json(StructField(a,YearMonthIntervalType(0,1),true), to_json(map(a, INTERVAL '32-10' YEAR TO MONTH), Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS from_json(to_json(map(a, INTERVAL '32-10' YEAR TO MONTH)))#x]
+Project [from_json(StructField(a,YearMonthIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles), false) AS from_json({"a":"1"})#x, to_json(from_json(StructField(a,YearMonthIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles), false), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1"}))#x, to_json(map(a, INTERVAL '32-10' YEAR TO MONTH), Some(America/Los_Angeles)) AS to_json(map(a, INTERVAL '32-10' YEAR TO MONTH))#x, from_json(StructField(a,YearMonthIntervalType(0,1),true), to_json(map(a, INTERVAL '32-10' YEAR TO MONTH), Some(America/Los_Angeles)), Some(America/Los_Angeles), false) AS from_json(to_json(map(a, INTERVAL '32-10' YEAR TO MONTH)))#x]
+- OneRowRelation
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/join-lateral.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/join-lateral.sql.out
index e81ee769f57d6..5bf893605423c 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/join-lateral.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/join-lateral.sql.out
@@ -3017,6 +3017,53 @@ Project [c1#x, c2#x, t#x]
+- LocalRelation [col1#x, col2#x]
+-- !query
+select 1
+from t1 as t_outer
+left join
+ lateral(
+ select b1,b2
+ from
+ (
+ select
+ t2.c1 as b1,
+ 1 as b2
+ from t2
+ union
+ select t_outer.c1 as b1,
+ null as b2
+ ) as t_inner
+ where (t_inner.b1 < t_outer.c2 or t_inner.b1 is null)
+ and t_inner.b1 = t_outer.c1
+ order by t_inner.b1,t_inner.b2 desc limit 1
+ ) as lateral_table
+-- !query analysis
+Project [1 AS 1#x]
++- LateralJoin lateral-subquery#x [c2#x && c1#x && c1#x], LeftOuter
+ : +- SubqueryAlias lateral_table
+ : +- GlobalLimit 1
+ : +- LocalLimit 1
+ : +- Sort [b1#x ASC NULLS FIRST, b2#x DESC NULLS LAST], true
+ : +- Project [b1#x, b2#x]
+ : +- Filter (((b1#x < outer(c2#x)) OR isnull(b1#x)) AND (b1#x = outer(c1#x)))
+ : +- SubqueryAlias t_inner
+ : +- Distinct
+ : +- Union false, false
+ : :- Project [c1#x AS b1#x, 1 AS b2#x]
+ : : +- SubqueryAlias spark_catalog.default.t2
+ : : +- View (`spark_catalog`.`default`.`t2`, [c1#x, c2#x])
+ : : +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
+ : : +- LocalRelation [col1#x, col2#x]
+ : +- Project [b1#x, cast(b2#x as int) AS b2#x]
+ : +- Project [outer(c1#x) AS b1#x, null AS b2#x]
+ : +- OneRowRelation
+ +- SubqueryAlias t_outer
+ +- SubqueryAlias spark_catalog.default.t1
+ +- View (`spark_catalog`.`default`.`t1`, [c1#x, c2#x])
+ +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
+ +- LocalRelation [col1#x, col2#x]
+
+
-- !query
DROP VIEW t1
-- !query analysis
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/json-functions.sql.out
index 0d7c6b2056231..fef9d0c5b6250 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/json-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/json-functions.sql.out
@@ -118,14 +118,14 @@ org.apache.spark.sql.AnalysisException
-- !query
select from_json('{"a":1}', 'a INT')
-- !query analysis
-Project [from_json(StructField(a,IntegerType,true), {"a":1}, Some(America/Los_Angeles)) AS from_json({"a":1})#x]
+Project [from_json(StructField(a,IntegerType,true), {"a":1}, Some(America/Los_Angeles), false) AS from_json({"a":1})#x]
+- OneRowRelation
-- !query
select from_json('{"time":"26/08/2015"}', 'time Timestamp', map('timestampFormat', 'dd/MM/yyyy'))
-- !query analysis
-Project [from_json(StructField(time,TimestampType,true), (timestampFormat,dd/MM/yyyy), {"time":"26/08/2015"}, Some(America/Los_Angeles)) AS from_json({"time":"26/08/2015"})#x]
+Project [from_json(StructField(time,TimestampType,true), (timestampFormat,dd/MM/yyyy), {"time":"26/08/2015"}, Some(America/Los_Angeles), false) AS from_json({"time":"26/08/2015"})#x]
+- OneRowRelation
@@ -279,14 +279,14 @@ DropTempViewCommand jsonTable
-- !query
select from_json('{"a":1, "b":2}', 'map')
-- !query analysis
-Project [from_json(MapType(StringType,IntegerType,true), {"a":1, "b":2}, Some(America/Los_Angeles)) AS entries#x]
+Project [from_json(MapType(StringType,IntegerType,true), {"a":1, "b":2}, Some(America/Los_Angeles), false) AS entries#x]
+- OneRowRelation
-- !query
select from_json('{"a":1, "b":"2"}', 'struct')
-- !query analysis
-Project [from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), {"a":1, "b":"2"}, Some(America/Los_Angeles)) AS from_json({"a":1, "b":"2"})#x]
+Project [from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), {"a":1, "b":"2"}, Some(America/Los_Angeles), false) AS from_json({"a":1, "b":"2"})#x]
+- OneRowRelation
@@ -300,70 +300,70 @@ Project [schema_of_json({"c1":0, "c2":[1]}) AS schema_of_json({"c1":0, "c2":[1]}
-- !query
select from_json('{"c1":[1, 2, 3]}', schema_of_json('{"c1":[0]}'))
-- !query analysis
-Project [from_json(StructField(c1,ArrayType(LongType,true),true), {"c1":[1, 2, 3]}, Some(America/Los_Angeles)) AS from_json({"c1":[1, 2, 3]})#x]
+Project [from_json(StructField(c1,ArrayType(LongType,true),true), {"c1":[1, 2, 3]}, Some(America/Los_Angeles), false) AS from_json({"c1":[1, 2, 3]})#x]
+- OneRowRelation
-- !query
select from_json('[1, 2, 3]', 'array<int>')
-- !query analysis
-Project [from_json(ArrayType(IntegerType,true), [1, 2, 3], Some(America/Los_Angeles)) AS from_json([1, 2, 3])#x]
+Project [from_json(ArrayType(IntegerType,true), [1, 2, 3], Some(America/Los_Angeles), false) AS from_json([1, 2, 3])#x]
+- OneRowRelation
-- !query
select from_json('[1, "2", 3]', 'array')
-- !query analysis
-Project [from_json(ArrayType(IntegerType,true), [1, "2", 3], Some(America/Los_Angeles)) AS from_json([1, "2", 3])#x]
+Project [from_json(ArrayType(IntegerType,true), [1, "2", 3], Some(America/Los_Angeles), false) AS from_json([1, "2", 3])#x]
+- OneRowRelation
-- !query
select from_json('[1, 2, null]', 'array<int>')
-- !query analysis
-Project [from_json(ArrayType(IntegerType,true), [1, 2, null], Some(America/Los_Angeles)) AS from_json([1, 2, null])#x]
+Project [from_json(ArrayType(IntegerType,true), [1, 2, null], Some(America/Los_Angeles), false) AS from_json([1, 2, null])#x]
+- OneRowRelation
-- !query
select from_json('[{"a": 1}, {"a":2}]', 'array>')
-- !query analysis
-Project [from_json(ArrayType(StructType(StructField(a,IntegerType,true)),true), [{"a": 1}, {"a":2}], Some(America/Los_Angeles)) AS from_json([{"a": 1}, {"a":2}])#x]
+Project [from_json(ArrayType(StructType(StructField(a,IntegerType,true)),true), [{"a": 1}, {"a":2}], Some(America/Los_Angeles), false) AS from_json([{"a": 1}, {"a":2}])#x]
+- OneRowRelation
-- !query
select from_json('{"a": 1}', 'array>')
-- !query analysis
-Project [from_json(ArrayType(StructType(StructField(a,IntegerType,true)),true), {"a": 1}, Some(America/Los_Angeles)) AS from_json({"a": 1})#x]
+Project [from_json(ArrayType(StructType(StructField(a,IntegerType,true)),true), {"a": 1}, Some(America/Los_Angeles), false) AS from_json({"a": 1})#x]
+- OneRowRelation
-- !query
select from_json('[null, {"a":2}]', 'array>')
-- !query analysis
-Project [from_json(ArrayType(StructType(StructField(a,IntegerType,true)),true), [null, {"a":2}], Some(America/Los_Angeles)) AS from_json([null, {"a":2}])#x]
+Project [from_json(ArrayType(StructType(StructField(a,IntegerType,true)),true), [null, {"a":2}], Some(America/Los_Angeles), false) AS from_json([null, {"a":2}])#x]
+- OneRowRelation
-- !query
select from_json('[{"a": 1}, {"b":2}]', 'array